org.apache.spark.sql.Row Scala Examples

The following examples show how to use org.apache.spark.sql.Row. Each example is taken from an open-source project; the project and source file are noted above the code so you can follow it back to the original.
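Before the project examples, here is a minimal standalone sketch (written for this page, not taken from any of the projects below) showing how a Row is typically constructed, read by position, and destructured with a pattern match; the pattern-match style is what most of the examples below rely on.

import org.apache.spark.sql.Row

object RowBasics {
  def main(args: Array[String]): Unit = {
    // A Row is an ordered, generically typed record; fields are set positionally.
    val row = Row("alice", 42, 3.14)

    // Read fields by position with typed getters.
    val name = row.getString(0)
    val age = row.getInt(1)
    println(s"$name / $age / ${row.getDouble(2)}")

    // Or destructure with a pattern match, as the test suites below do.
    row match {
      case Row(n: String, a: Int, score: Double) =>
        println(s"$n is $a with score $score")
    }
  }
}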
Example 1
Source File: TextFileFormat.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.text

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.apache.hadoop.util.ReflectionUtils

import org.apache.spark.TaskContext
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration


// Excerpt note: only this helper from the original file is shown and its enclosing
// declaration was dropped by the listing; a plain wrapper object (not part of the
// original file) is added here so the snippet compiles on its own.
object TextFileFormatExcerpt {

  def getCompressionExtension(context: TaskAttemptContext): String = {
    // Set the compression extension, similar to code in TextOutputFormat.getDefaultWorkFile
    if (FileOutputFormat.getCompressOutput(context)) {
      val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec])
      ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension
    } else {
      ""
    }
  }
} 
Example 2
Source File: DataFrameExample.scala    From drizzle-spark   with Apache License 2.0
// scalastyle:off println
package org.apache.spark.examples.ml

import java.io.File

import scopt.OptionParser

import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.util.Utils


object DataFrameExample {

  case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
    extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("DataFrameExample") {
      head("DataFrameExample: an example app using DataFrame for ML.")
      opt[String]("input")
        .text(s"input path to dataframe")
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        success
      }
    }

    parser.parse(args, defaultParams) match {
      case Some(params) => run(params)
      case _ => sys.exit(1)
    }
  }

  def run(params: Params): Unit = {
    val spark = SparkSession
      .builder
      .appName(s"DataFrameExample with $params")
      .getOrCreate()

    // Load input data
    println(s"Loading LIBSVM file with UDT from ${params.input}.")
    val df: DataFrame = spark.read.format("libsvm").load(params.input).cache()
    println("Schema from LIBSVM:")
    df.printSchema()
    println(s"Loaded training data as a DataFrame with ${df.count()} records.")

    // Show statistical summary of labels.
    val labelSummary = df.describe("label")
    labelSummary.show()

    // Convert features column to an RDD of vectors.
    val features = df.select("features").rdd.map { case Row(v: Vector) => v }
    val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
      (summary, feat) => summary.add(Vectors.fromML(feat)),
      (sum1, sum2) => sum1.merge(sum2))
    println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

    // Save the records in a parquet file.
    val tmpDir = Utils.createTempDir()
    val outputDir = new File(tmpDir, "dataframe").toString
    println(s"Saving to $outputDir as Parquet file.")
    df.write.parquet(outputDir)

    // Load the records back.
    println(s"Loading Parquet file with UDT from $outputDir.")
    val newDF = spark.read.parquet(outputDir)
    println(s"Schema from Parquet:")
    newDF.printSchema()

    spark.stop()
  }
}
// scalastyle:on println 
Example 3
Source File: TsStreamingTest.scala    From spark-riak-connector   with Apache License 2.0
package com.basho.riak.spark.streaming

import java.nio.ByteBuffer
import java.util.concurrent.{Callable, Executors, TimeUnit}

import com.basho.riak.spark._
import com.basho.riak.spark.rdd.RiakTSTests
import com.basho.riak.spark.rdd.timeseries.{AbstractTimeSeriesTest, TimeSeriesData}
import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper, SerializationFeature}
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.spark.sql.Row
import org.junit.Assert._
import org.junit.experimental.categories.Category
import org.junit.{After, Before, Test}

@Category(Array(classOf[RiakTSTests]))
class TsStreamingTest extends AbstractTimeSeriesTest(false) with SparkStreamingFixture {

  protected final val executorService = Executors.newCachedThreadPool()
  private val dataSource = new SocketStreamingDataSource
  private var port = -1

  @Before
  def setUp(): Unit = {
    port = dataSource.start(client => {
      testData
        .map(tolerantMapper.writeValueAsString)
        .foreach(x => client.write(ByteBuffer.wrap(s"$x\n".getBytes)))
      logInfo(s"${testData.length} values were sent to the client")
    })
  }

  @After
  def tearDown(): Unit = {
    dataSource.stop()
  }

  @Test(timeout = 10 * 1000) // 10 seconds timeout
  def saveToRiak(): Unit = {
    executorService.submit(new Runnable {
      override def run(): Unit = {
        ssc.socketTextStream("localhost", port)
          .map(string => {
            val tsdata = new ObjectMapper()
              .configure(DeserializationFeature.FAIL_ON_NULL_FOR_PRIMITIVES, true)
              .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true)
              .configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, true)
              .configure(JsonParser.Feature.ALLOW_UNQUOTED_FIELD_NAMES, true)
              .configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false)
              .registerModule(DefaultScalaModule)
              .readValue(string, classOf[TimeSeriesData])
            Row(1, "f", tsdata.time, tsdata.user_id, tsdata.temperature_k)
          })
          .saveToRiakTS(bucketName)

        ssc.start()
        ssc.awaitTerminationOrTimeout(5 * 1000)
      }
    })

    val result = executorService.submit(new Callable[Array[Seq[Any]]] {
      override def call(): Array[Seq[Any]] = {
        var rdd = sc.riakTSTable[Row](bucketName)
          .sql(s"SELECT user_id, temperature_k FROM $bucketName $sqlWhereClause")
        var count = rdd.count()
        while (count < testData.length) {
          TimeUnit.SECONDS.sleep(2)

          rdd = sc.riakTSTable[Row](bucketName)
            .sql(s"SELECT user_id, temperature_k FROM $bucketName $sqlWhereClause")
          count = rdd.count()
        }
        rdd.collect().map(_.toSeq)
      }
    }).get()

    assertEquals(testData.length, result.length)
    assertEqualsUsingJSONIgnoreOrder(
      """
        |[
        |   ['bryce',305.37],
        |   ['bryce',300.12],
        |   ['bryce',295.95],
        |   ['ratman',362.121],
        |   ['ratman',3502.212]
        |]
      """.stripMargin, result)
  }
} 
Example 4
Source File: OnErrorSuite.scala    From spark-snowflake   with Apache License 2.0
package net.snowflake.spark.snowflake

import net.snowflake.client.jdbc.SnowflakeSQLException
import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_NAME
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class OnErrorSuite extends IntegrationSuiteBase {
  lazy val table = s"spark_test_table_$randomSuffix"

  lazy val schema = new StructType(
    Array(StructField("var", StringType, nullable = false))
  )

  lazy val df: DataFrame = sparkSession.createDataFrame(
    sc.parallelize(
      Seq(Row("{\"dsadas\nadsa\":12311}"), Row("{\"abc\":334}")) // invalid json key
    ),
    schema
  )

  override def beforeAll(): Unit = {
    super.beforeAll()
    jdbcUpdate(s"create or replace table $table(var variant)")
  }

  override def afterAll(): Unit = {
    jdbcUpdate(s"drop table $table")
    super.afterAll()
  }

  test("continue_on_error off") {

    assertThrows[SnowflakeSQLException] {
      df.write
        .format(SNOWFLAKE_SOURCE_NAME)
        .options(connectorOptionsNoTable)
        .option("dbtable", table)
        .mode(SaveMode.Append)
        .save()
    }
  }

  test("continue_on_error on") {
    df.write
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(connectorOptionsNoTable)
      .option("continue_on_error", "on")
      .option("dbtable", table)
      .mode(SaveMode.Append)
      .save()

    val result = sparkSession.read
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(connectorOptionsNoTable)
      .option("dbtable", table)
      .load()

    assert(result.collect().length == 1)
  }

} 
Example 5
Source File: SqlUnitTest.scala    From SparkUnitTestingExamples   with Apache License 2.0
package com.cloudera.sa.spark.unittest.sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}

import scala.collection.mutable

class SqlUnitTest extends FunSuite
  with BeforeAndAfterEach with BeforeAndAfterAll {

  @transient var sc: SparkContext = null
  @transient var hiveContext: HiveContext = null

  override def beforeAll(): Unit = {

    val envMap = Map[String,String](("Xmx", "512m"))

    val sparkConfig = new SparkConf()
    sparkConfig.set("spark.broadcast.compress", "false")
    sparkConfig.set("spark.shuffle.compress", "false")
    sparkConfig.set("spark.shuffle.spill.compress", "false")
    sparkConfig.set("spark.io.compression.codec", "lzf")
    sc = new SparkContext("local[2]", "unit test", sparkConfig)
    hiveContext = new HiveContext(sc)
  }

  override def afterAll(): Unit = {
    sc.stop()
  }

  test("Test table creation and summing of counts") {
    val personRDD = sc.parallelize(Seq(Row("ted", 42, "blue"),
      Row("tj", 11, "green"),
      Row("andrew", 9, "green")))

    hiveContext.sql("create table person (name string, age int, color string)")

    val emptyDataFrame = hiveContext.sql("select * from person limit 0")

    val personDataFrame = hiveContext.createDataFrame(personRDD, emptyDataFrame.schema)
    personDataFrame.registerTempTable("tempPerson")

    val ageSumDataFrame = hiveContext.sql("select sum(age) from tempPerson")

    val localAgeSum = ageSumDataFrame.take(10)

    assert(localAgeSum(0).get(0) == 62, "The sum of age should equal 62 but it equaled " + localAgeSum(0).get(0))
  }
} 
Example 6
Source File: SparkPFASuiteBase.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.pfa

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.SparkConf
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.scalactic.Equality
import org.scalatest.FunSuite

abstract class SparkPFASuiteBase extends FunSuite with DataFrameSuiteBase with PFATestUtils {

  val sparkTransformer: Transformer
  val input: Array[String]
  val expectedOutput: Array[String]

  val sparkConf =  new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID).
    set("spark.driver.host", "localhost")
  override lazy val spark = SparkSession.builder().config(sparkConf).getOrCreate()
  override val reuseContextIfPossible = true

  // Converts column containing a vector to an array
  def withColumnAsArray(df: DataFrame, colName: String) = {
    val vecToArray = udf { v: Vector => v.toArray }
    df.withColumn(colName, vecToArray(df(colName)))
  }

  def withColumnAsArray(df: DataFrame, first: String, others: String*) = {
    val vecToArray = udf { v: Vector => v.toArray }
    var result = df.withColumn(first, vecToArray(df(first)))
    others.foreach(c => result = result.withColumn(c, vecToArray(df(c))))
    result
  }

  // Converts column containing a vector to a sparse vector represented as a map
  def getColumnAsSparseVectorMap(df: DataFrame, colName: String) = {
    val vecToMap = udf { v: Vector => v.toSparse.indices.map(i => (i.toString, v(i))).toMap }
    df.withColumn(colName, vecToMap(df(colName)))
  }

}

abstract class Result

object ApproxEquality extends ApproxEquality

trait ApproxEquality {

  import org.scalactic.Tolerance._
  import org.scalactic.TripleEquals._

  implicit val seqApproxEq: Equality[Seq[Double]] = new Equality[Seq[Double]] {
    override def areEqual(a: Seq[Double], b: Any): Boolean = {
      b match {
        case d: Seq[Double] =>
          a.zip(d).forall { case (l, r) => l === r +- 0.001 }
        case _ =>
          false
      }
    }
  }

  implicit val vectorApproxEq: Equality[Vector] = new Equality[Vector] {
    override def areEqual(a: Vector, b: Any): Boolean = {
      b match {
        case v: Vector =>
          a.toArray.zip(v.toArray).forall { case (l, r) => l === r +- 0.001 }
        case _ =>
          false
      }
    }
  }
} 
Example 7
Source File: DNSstat.scala    From jdbcsink   with Apache License 2.0
import org.apache.spark.sql.SparkSession
import java.util.Properties
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.{from_json,window}
import java.sql.{Connection,Statement,DriverManager}
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.Row

class JDBCSink() extends ForeachWriter[Row] {
  val driver = "com.mysql.jdbc.Driver"
  var connection: Connection = _
  var statement: Statement = _

  def open(partitionId: Long, version: Long): Boolean = {
    Class.forName(driver)
    connection = DriverManager.getConnection("jdbc:mysql://10.88.1.102:3306/aptwebservice", "root", "mysqladmin")
    statement = connection.createStatement
    true
  }

  def process(value: Row): Unit = {
    statement.executeUpdate("replace into DNSStat(ip,domain,time,count) values("
      + "'" + value.getString(0) + "'" + "," // ip
      + "'" + value.getString(1) + "'" + "," // domain
      + "'" + value.getTimestamp(2) + "'" + "," // time
      + value.getLong(3) // count
      + ")")
  }

  def close(errorOrNull: Throwable): Unit = {
    connection.close()
  }
}

object DNSstatJob{

val schema: StructType = StructType(
        Seq(StructField("Vendor", StringType,true),
         StructField("Id", IntegerType,true),
         StructField("Time", LongType,true),
         StructField("Conn", StructType(Seq(
                                        StructField("Proto", IntegerType, true), 
                                        StructField("Sport", IntegerType, true), 
                                        StructField("Dport", IntegerType, true), 
                                        StructField("Sip", StringType, true), 
                                        StructField("Dip", StringType, true)
                                        )), true),
        StructField("Dns", StructType(Seq(
                                        StructField("Domain", StringType, true), 
                                        StructField("IpCount", IntegerType, true), 
                                        StructField("Ip", StringType, true) 
                                        )), true)))

  def main(args: Array[String]) {
    val spark = SparkSession
          .builder
          .appName("DNSJob")
          .config("spark.some.config.option", "some-value")
          .getOrCreate()
    import spark.implicits._
    val connectionProperties = new Properties()
    connectionProperties.put("user", "root")
    connectionProperties.put("password", "mysqladmin")
    val bruteForceTab = spark.read
                .jdbc("jdbc:mysql://10.88.1.102:3306/aptwebservice", "DNSTab",connectionProperties)
    bruteForceTab.registerTempTable("DNSTab")
    val lines = spark
          .readStream
          .format("kafka")
          .option("kafka.bootstrap.servers", "10.94.1.110:9092")
          .option("subscribe","xdr")
          //.option("startingOffsets","earliest")
          .option("startingOffsets","latest")
          .load()
          .select(from_json($"value".cast(StringType),schema).as("jsonData"))
    lines.registerTempTable("xdr")
    val filterDNS = spark.sql("select CAST(from_unixtime(xdr.jsonData.Time DIV 1000000) as timestamp) as time,xdr.jsonData.Conn.Sip as sip, xdr.jsonData.Dns.Domain from xdr inner join DNSTab on xdr.jsonData.Dns.domain = DNSTab.domain")
    
    val windowedCounts = filterDNS
                        .withWatermark("time","5 minutes")
                        .groupBy(window($"time", "1 minutes", "1 minutes"),$"sip",$"domain")
                        .count()
                        .select($"sip",$"domain",$"window.start",$"count")

    val writer = new JDBCSink()
    val query = windowedCounts
      .writeStream
      .foreach(writer)
      .outputMode("update")
      .option("checkpointLocation", "/checkpoint/")
      .start()

    query.awaitTermination()
  }
}
Example 8
Source File: GLMRegressionModel.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.mllib.regression.impl

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
import org.apache.spark.sql.{Row, SparkSession}


// Excerpt note: the enclosing objects from the original Spark source
// (GLMRegressionModel and its SaveLoadV1_0 loader, plus the Data case class)
// were dropped by this listing; they are restored here so the snippet
// compiles on its own.
object GLMRegressionModel {

  object SaveLoadV1_0 {

    case class Data(weights: Vector, intercept: Double)

    def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
      val dataPath = Loader.dataPath(path)
      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
      val dataRDD = spark.read.parquet(dataPath)
      val dataArray = dataRDD.select("weights", "intercept").take(1)
      assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath")
      val data = dataArray(0)
      assert(data.size == 2, s"Unable to load $modelClass data from: $dataPath")
      data match {
        case Row(weights: Vector, intercept: Double) =>
          assert(weights.size == numFeatures, s"Expected $numFeatures features, but" +
            s" found ${weights.size} features when loading $modelClass weights from $dataPath")
          Data(weights, intercept)
      }
    }
  }

} 
Example 9
Source File: MaxAbsScalerSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.sql.Row

class MaxAbsScalerSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  test("MaxAbsScaler fit basic case") {
    val data = Array(
      Vectors.dense(1, 0, 100),
      Vectors.dense(2, 0, 0),
      Vectors.sparse(3, Array(0, 2), Array(-2, -100)),
      Vectors.sparse(3, Array(0), Array(-1.5)))

    val expected: Array[Vector] = Array(
      Vectors.dense(0.5, 0, 1),
      Vectors.dense(1, 0, 0),
      Vectors.sparse(3, Array(0, 2), Array(-1, -1)),
      Vectors.sparse(3, Array(0), Array(-0.75)))

    val df = data.zip(expected).toSeq.toDF("features", "expected")
    val scaler = new MaxAbsScaler()
      .setInputCol("features")
      .setOutputCol("scaled")

    val model = scaler.fit(df)
    testTransformer[(Vector, Vector)](df, model, "expected", "scaled") {
      case Row(expectedVec: Vector, actualVec: Vector) =>
        assert(expectedVec === actualVec,
          s"MaxAbsScaler error: Expected $expectedVec but computed $actualVec")
    }

    MLTestingUtils.checkCopyAndUids(scaler, model)
  }

  test("MaxAbsScaler read/write") {
    val t = new MaxAbsScaler()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
    testDefaultReadWrite(t)
  }

  test("MaxAbsScalerModel read/write") {
    val instance = new MaxAbsScalerModel(
      "myMaxAbsScalerModel", Vectors.dense(1.0, 10.0))
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
    val newInstance = testDefaultReadWrite(instance)
    assert(newInstance.maxAbs === instance.maxAbs)
  }

} 
Example 10
Source File: DCTSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.feature

import scala.beans.BeanInfo

import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.Row

@BeanInfo
case class DCTTestData(vec: Vector, wantedVec: Vector)

class DCTSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  test("forward transform of discrete cosine matches jTransforms result") {
    val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
    val inverse = false

    testDCT(data, inverse)
  }

  test("inverse transform of discrete cosine matches jTransforms result") {
    val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
    val inverse = true

    testDCT(data, inverse)
  }

  test("read/write") {
    val t = new DCT()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setInverse(true)
    testDefaultReadWrite(t)
  }

  private def testDCT(data: Vector, inverse: Boolean): Unit = {
    val expectedResultBuffer = data.toArray.clone()
    if (inverse) {
      new DoubleDCT_1D(data.size).inverse(expectedResultBuffer, true)
    } else {
      new DoubleDCT_1D(data.size).forward(expectedResultBuffer, true)
    }
    val expectedResult = Vectors.dense(expectedResultBuffer)

    val dataset = Seq(DCTTestData(data, expectedResult)).toDF()

    val transformer = new DCT()
      .setInputCol("vec")
      .setOutputCol("resultVec")
      .setInverse(inverse)

    testTransformer[(Vector, Vector)](dataset, transformer, "resultVec", "wantedVec") {
      case Row(resultVec: Vector, wantedVec: Vector) =>
        assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
    }
  }
} 
Example 11
Source File: ElementwiseProductSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.sql.Row

class ElementwiseProductSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  test("streaming transform") {
    val scalingVec = Vectors.dense(0.1, 10.0)
    val data = Seq(
      (Vectors.dense(0.1, 1.0), Vectors.dense(0.01, 10.0)),
      (Vectors.dense(0.0, -1.1), Vectors.dense(0.0, -11.0))
    )
    val df = spark.createDataFrame(data).toDF("features", "expected")
    val ep = new ElementwiseProduct()
      .setInputCol("features")
      .setOutputCol("actual")
      .setScalingVec(scalingVec)
    testTransformer[(Vector, Vector)](df, ep, "actual", "expected") {
      case Row(actual: Vector, expected: Vector) =>
        assert(actual ~== expected relTol 1e-14)
    }
  }

  test("read/write") {
    val ep = new ElementwiseProduct()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setScalingVec(Vectors.dense(0.1, 0.2))
    testDefaultReadWrite(ep)
  }
} 
Example 12
Source File: BinarizerSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.{DataFrame, Row}

class BinarizerSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  @transient var data: Array[Double] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
  }

  test("params") {
    ParamsSuite.checkParams(new Binarizer)
  }

  test("Binarize continuous features with default parameter") {
    val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
    val dataFrame: DataFrame = data.zip(defaultBinarized).toSeq.toDF("feature", "expected")

    val binarizer: Binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")

    testTransformer[(Double, Double)](dataFrame, binarizer, "binarized_feature", "expected") {
      case Row(x: Double, y: Double) =>
        assert(x === y, "The feature value is not correct after binarization.")
    }
  }

  test("Binarize continuous features with setter") {
    val threshold: Double = 0.2
    val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
    val dataFrame: DataFrame = data.zip(thresholdBinarized).toSeq.toDF("feature", "expected")

    val binarizer: Binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")
      .setThreshold(threshold)

    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
      case Row(x: Double, y: Double) =>
        assert(x === y, "The feature value is not correct after binarization.")
    }
  }

  test("Binarize vector of continuous features with default parameter") {
    val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
    val dataFrame: DataFrame = Seq(
      (Vectors.dense(data), Vectors.dense(defaultBinarized))
    ).toDF("feature", "expected")

    val binarizer: Binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")

    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
      case Row(x: Vector, y: Vector) =>
        assert(x == y, "The feature value is not correct after binarization.")
    }
  }

  test("Binarize vector of continuous features with setter") {
    val threshold: Double = 0.2
    val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
    val dataFrame: DataFrame = Seq(
      (Vectors.dense(data), Vectors.dense(defaultBinarized))
    ).toDF("feature", "expected")

    val binarizer: Binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")
      .setThreshold(threshold)

    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
      case Row(x: Vector, y: Vector) =>
        assert(x == y, "The feature value is not correct after binarization.")
    }
  }


  test("read/write") {
    val t = new Binarizer()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setThreshold(0.1)
    testDefaultReadWrite(t)
  }
} 
Example 13
Source File: TokenizerSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.feature

import scala.beans.BeanInfo

import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.{DataFrame, Row}

@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])

class TokenizerSuite extends MLTest with DefaultReadWriteTest {

  test("params") {
    ParamsSuite.checkParams(new Tokenizer)
  }

  test("read/write") {
    val t = new Tokenizer()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
    testDefaultReadWrite(t)
  }
}

class RegexTokenizerSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  def testRegexTokenizer(t: RegexTokenizer, dataframe: DataFrame): Unit = {
    testTransformer[(String, Seq[String])](dataframe, t, "tokens", "wantedTokens") {
      case Row(tokens, wantedTokens) =>
        assert(tokens === wantedTokens)
    }
  }

  test("params") {
    ParamsSuite.checkParams(new RegexTokenizer)
  }

  test("RegexTokenizer") {
    val tokenizer0 = new RegexTokenizer()
      .setGaps(false)
      .setPattern("\\w+|\\p{Punct}")
      .setInputCol("rawText")
      .setOutputCol("tokens")
    val dataset0 = Seq(
      TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")),
      TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct"))
    ).toDF()
    testRegexTokenizer(tokenizer0, dataset0)

    val dataset1 = Seq(
      TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")),
      TokenizerTestData("Te,st. punct", Array("punct"))
    ).toDF()
    tokenizer0.setMinTokenLength(3)
    testRegexTokenizer(tokenizer0, dataset1)

    val tokenizer2 = new RegexTokenizer()
      .setInputCol("rawText")
      .setOutputCol("tokens")
    val dataset2 = Seq(
      TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")),
      TokenizerTestData("Te,st.  punct", Array("te,st.", "punct"))
    ).toDF()
    testRegexTokenizer(tokenizer2, dataset2)
  }

  test("RegexTokenizer with toLowercase false") {
    val tokenizer = new RegexTokenizer()
      .setInputCol("rawText")
      .setOutputCol("tokens")
      .setToLowercase(false)
    val dataset = Seq(
      TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")),
      TokenizerTestData("java scala", Array("java", "scala"))
    ).toDF()
    testRegexTokenizer(tokenizer, dataset)
  }

  test("read/write") {
    val t = new RegexTokenizer()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setMinTokenLength(2)
      .setGaps(false)
      .setPattern("hi")
      .setToLowercase(false)
    testDefaultReadWrite(t)
  }
} 
Example 14
Source File: MinMaxScalerSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.sql.Row

class MinMaxScalerSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  test("MinMaxScaler fit basic case") {
    val data = Array(
      Vectors.dense(1, 0, Long.MinValue),
      Vectors.dense(2, 0, 0),
      Vectors.sparse(3, Array(0, 2), Array(3, Long.MaxValue)),
      Vectors.sparse(3, Array(0), Array(1.5)))

    val expected: Array[Vector] = Array(
      Vectors.dense(-5, 0, -5),
      Vectors.dense(0, 0, 0),
      Vectors.sparse(3, Array(0, 2), Array(5, 5)),
      Vectors.sparse(3, Array(0), Array(-2.5)))

    val df = data.zip(expected).toSeq.toDF("features", "expected")
    val scaler = new MinMaxScaler()
      .setInputCol("features")
      .setOutputCol("scaled")
      .setMin(-5)
      .setMax(5)

    val model = scaler.fit(df)
    testTransformer[(Vector, Vector)](df, model, "expected", "scaled") {
      case Row(vector1: Vector, vector2: Vector) =>
        assert(vector1 === vector2, "Transformed vector is different with expected.")
    }

    MLTestingUtils.checkCopyAndUids(scaler, model)
  }

  test("MinMaxScaler arguments max must be larger than min") {
    withClue("arguments max must be larger than min") {
      val dummyDF = Seq((1, Vectors.dense(1.0, 2.0))).toDF("id", "features")
      intercept[IllegalArgumentException] {
        val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("features")
        scaler.transformSchema(dummyDF.schema)
      }
      intercept[IllegalArgumentException] {
        val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("features")
        scaler.transformSchema(dummyDF.schema)
      }
    }
  }

  test("MinMaxScaler read/write") {
    val t = new MinMaxScaler()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setMax(1.0)
      .setMin(-1.0)
    testDefaultReadWrite(t)
  }

  test("MinMaxScalerModel read/write") {
    val instance = new MinMaxScalerModel(
        "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0))
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setMin(-1.0)
      .setMax(1.0)
    val newInstance = testDefaultReadWrite(instance)
    assert(newInstance.originalMin === instance.originalMin)
    assert(newInstance.originalMax === instance.originalMax)
  }

  test("MinMaxScaler should remain NaN value") {
    val data = Array(
      Vectors.dense(1, Double.NaN, 2.0, 2.0),
      Vectors.dense(2, 2.0, 0.0, 3.0),
      Vectors.dense(3, Double.NaN, 0.0, 1.0),
      Vectors.dense(6, 2.0, 2.0, Double.NaN))

    val expected: Array[Vector] = Array(
      Vectors.dense(-5.0, Double.NaN, 5.0, 0.0),
      Vectors.dense(-3.0, 0.0, -5.0, 5.0),
      Vectors.dense(-1.0, Double.NaN, -5.0, -5.0),
      Vectors.dense(5.0, 0.0, 5.0, Double.NaN))

    val df = data.zip(expected).toSeq.toDF("features", "expected")
    val scaler = new MinMaxScaler()
      .setInputCol("features")
      .setOutputCol("scaled")
      .setMin(-5)
      .setMax(5)

    val model = scaler.fit(df)
    model.transform(df).select("expected", "scaled").collect()
      .foreach { case Row(vector1: Vector, vector2: Vector) =>
        assert(vector1 === vector2, "Transformed vector is different with expected.")
      }
  }
} 
Example 15
Source File: NormalizerSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}


class NormalizerSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  @transient var data: Array[Vector] = _
  @transient var l1Normalized: Array[Vector] = _
  @transient var l2Normalized: Array[Vector] = _

  override def beforeAll(): Unit = {
    super.beforeAll()

    data = Array(
      Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.6, -1.1, -3.0),
      Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))),
      Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))),
      Vectors.sparse(3, Seq())
    )
    l1Normalized = Array(
      Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.12765957, -0.23404255, -0.63829787),
      Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))),
      Vectors.dense(0.625, 0.07894737, 0.29605263),
      Vectors.sparse(3, Seq())
    )
    l2Normalized = Array(
      Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.184549876, -0.3383414, -0.922749378),
      Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))),
      Vectors.dense(0.897906166, 0.113419726, 0.42532397),
      Vectors.sparse(3, Seq())
    )
  }

  def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = {
    assert((lhs, rhs) match {
      case (v1: DenseVector, v2: DenseVector) => true
      case (v1: SparseVector, v2: SparseVector) => true
      case _ => false
    }, "The vector type should be preserved after normalization.")
  }

  def assertValues(lhs: Vector, rhs: Vector): Unit = {
    assert(lhs ~== rhs absTol 1E-5, "The vector value is not correct after normalization.")
  }

  test("Normalization with default parameter") {
    val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized")
    val dataFrame: DataFrame = data.zip(l2Normalized).seq.toDF("features", "expected")

    testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") {
      case Row(features: Vector, normalized: Vector, expected: Vector) =>
        assertTypeOfVector(normalized, features)
        assertValues(normalized, expected)
    }
  }

  test("Normalization with setter") {
    val dataFrame: DataFrame = data.zip(l1Normalized).seq.toDF("features", "expected")
    val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized").setP(1)

    testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") {
      case Row(features: Vector, normalized: Vector, expected: Vector) =>
        assertTypeOfVector(normalized, features)
        assertValues(normalized, expected)
    }
  }

  test("read/write") {
    val t = new Normalizer()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setP(3.0)
    testDefaultReadWrite(t)
  }
} 
Example 16
Source File: NGramSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.feature

import scala.beans.BeanInfo

import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.{DataFrame, Row}


@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])

class NGramSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  test("default behavior yields bigram features") {
    val nGram = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
    val dataset = Seq(NGramTestData(
      Array("Test", "for", "ngram", "."),
      Array("Test for", "for ngram", "ngram .")
    )).toDF()
    testNGram(nGram, dataset)
  }

  test("NGramLength=4 yields length 4 n-grams") {
    val nGram = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
      .setN(4)
    val dataset = Seq(NGramTestData(
      Array("a", "b", "c", "d", "e"),
      Array("a b c d", "b c d e")
    )).toDF()
    testNGram(nGram, dataset)
  }

  test("empty input yields empty output") {
    val nGram = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
      .setN(4)
    val dataset = Seq(NGramTestData(Array(), Array())).toDF()
    testNGram(nGram, dataset)
  }

  test("input array < n yields empty output") {
    val nGram = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
      .setN(6)
    val dataset = Seq(NGramTestData(
      Array("a", "b", "c", "d", "e"),
      Array()
    )).toDF()
    testNGram(nGram, dataset)
  }

  test("read/write") {
    val t = new NGram()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setN(3)
    testDefaultReadWrite(t)
  }

  def testNGram(t: NGram, dataFrame: DataFrame): Unit = {
    testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") {
      case Row(actualNGrams : Seq[_], wantedNGrams: Seq[_]) =>
        assert(actualNGrams === wantedNGrams)
    }
  }
} 
Example 17
Source File: PCASuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.sql.Row

class PCASuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  test("params") {
    ParamsSuite.checkParams(new PCA)
    val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
    val explainedVariance = Vectors.dense(0.5, 0.5).asInstanceOf[DenseVector]
    val model = new PCAModel("pca", mat, explainedVariance)
    ParamsSuite.checkParams(model)
  }

  test("pca") {
    val data = Array(
      Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
    )

    val dataRDD = sc.parallelize(data, 2)

    val mat = new RowMatrix(dataRDD.map(OldVectors.fromML))
    val pc = mat.computePrincipalComponents(3)
    val expected = mat.multiply(pc).rows.map(_.asML)

    val df = dataRDD.zip(expected).toDF("features", "expected")

    val pca = new PCA()
      .setInputCol("features")
      .setOutputCol("pca_features")
      .setK(3)

    val pcaModel = pca.fit(df)

    MLTestingUtils.checkCopyAndUids(pca, pcaModel)
    testTransformer[(Vector, Vector)](df, pcaModel, "pca_features", "expected") {
      case Row(result: Vector, expected: Vector) =>
        assert(result ~== expected absTol 1e-5,
          "Transformed vector is different with expected vector.")
    }
  }

  test("PCA read/write") {
    val t = new PCA()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setK(3)
    testDefaultReadWrite(t)
  }

  test("PCAModel read/write") {
    val instance = new PCAModel("myPCAModel",
      Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix],
      Vectors.dense(0.5, 0.5).asInstanceOf[DenseVector])
    val newInstance = testDefaultReadWrite(instance)
    assert(newInstance.pc === instance.pc)
  }
} 
Example 18
Source File: HashingTFSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF}
import org.apache.spark.sql.Row
import org.apache.spark.util.Utils

class HashingTFSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._
  import HashingTFSuite.murmur3FeatureIdx

  test("params") {
    ParamsSuite.checkParams(new HashingTF)
  }

  test("hashingTF") {
    val numFeatures = 100
    // Assume perfect hash when computing expected features.
    def idx: Any => Int = murmur3FeatureIdx(numFeatures)
    val data = Seq(
      ("a a b b c d".split(" ").toSeq,
        Vectors.sparse(numFeatures,
          Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))))
    )

    val df = data.toDF("words", "expected")
    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("features")
      .setNumFeatures(numFeatures)
    val output = hashingTF.transform(df)
    val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
    require(attrGroup.numAttributes === Some(numFeatures))

    testTransformer[(Seq[String], Vector)](df, hashingTF, "features", "expected") {
      case Row(features: Vector, expected: Vector) =>
        assert(features ~== expected absTol 1e-14)
    }
  }

  test("applying binary term freqs") {
    val df = Seq((0, "a a b c c c".split(" ").toSeq)).toDF("id", "words")
    val n = 100
    val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("features")
        .setNumFeatures(n)
        .setBinary(true)
    val output = hashingTF.transform(df)
    val features = output.select("features").first().getAs[Vector](0)
    def idx: Any => Int = murmur3FeatureIdx(n)  // Assume perfect hash on input features
    val expected = Vectors.sparse(n,
      Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
    assert(features ~== expected absTol 1e-14)
  }

  test("read/write") {
    val t = new HashingTF()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setNumFeatures(10)
    testDefaultReadWrite(t)
  }

}

object HashingTFSuite {

  private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
    Utils.nonNegativeMod(MLlibHashingTF.murmur3Hash(term), numFeatures)
  }

} 
Example 19
Source File: CorrelationSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.stat

import breeze.linalg.{DenseMatrix => BDM}

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.{Matrices, Matrix, Vectors}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}


class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {

  val xData = Array(1.0, 0.0, -2.0)
  val yData = Array(4.0, 5.0, 3.0)
  val zeros = new Array[Double](3)
  val data = Seq(
    Vectors.dense(1.0, 0.0, 0.0, -2.0),
    Vectors.dense(4.0, 5.0, 0.0, 3.0),
    Vectors.dense(6.0, 7.0, 0.0, 8.0),
    Vectors.dense(9.0, 0.0, 0.0, 1.0)
  )

  private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

  private def extract(df: DataFrame): BDM[Double] = {
    val Array(Row(mat: Matrix)) = df.collect()
    mat.asBreeze.toDenseMatrix
  }


  test("corr(X) default, pearson") {
    val defaultMat = Correlation.corr(X, "features")
    val pearsonMat = Correlation.corr(X, "features", "pearson")
    // scalastyle:off
    val expected = Matrices.fromBreeze(BDM(
      (1.00000000, 0.05564149, Double.NaN, 0.4004714),
      (0.05564149, 1.00000000, Double.NaN, 0.9135959),
      (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
      (0.40047142, 0.91359586, Double.NaN, 1.0000000)))
    // scalastyle:on

    assert(Matrices.fromBreeze(extract(defaultMat)) ~== expected absTol 1e-4)
    assert(Matrices.fromBreeze(extract(pearsonMat)) ~== expected absTol 1e-4)
  }

  test("corr(X) spearman") {
    val spearmanMat = Correlation.corr(X, "features", "spearman")
    // scalastyle:off
    val expected = Matrices.fromBreeze(BDM(
      (1.0000000,  0.1054093,  Double.NaN, 0.4000000),
      (0.1054093,  1.0000000,  Double.NaN, 0.9486833),
      (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
      (0.4000000,  0.9486833,  Double.NaN, 1.0000000)))
    // scalastyle:on
    assert(Matrices.fromBreeze(extract(spearmanMat)) ~== expected absTol 1e-4)
  }

} 
Example 20
Source File: MLTestSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.util

import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.Row

class MLTestSuite extends MLTest {

  import testImplicits._

  test("test transformer on stream data") {

    val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"), (5, "f"))
      .toDF("id", "label")
    val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
      .setInputCol("label").setOutputCol("indexed")
    val indexerModel = indexer.fit(data)
    testTransformer[(Int, String)](data, indexerModel, "id", "indexed") {
      case Row(id: Int, indexed: Double) =>
        assert(id === indexed.toInt)
    }
    testTransformerByGlobalCheckFunc[(Int, String)] (data, indexerModel, "id", "indexed") { rows =>
      assert(rows.map(_.getDouble(1)).max === 5.0)
    }

    intercept[Exception] {
      testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
        case Row(id: Int, indexed: Double) =>
          assert(id != indexed.toInt)
      }
    }
    intercept[Exception] {
      testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
        rows: Seq[Row] =>
          assert(rows.map(_.getDouble(1)).max === 1.0)
      }
    }
  }
} 
Example 21
Source File: ImageSchemaSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.ml.image

import java.nio.file.Paths
import java.util.Arrays

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.image.ImageSchema._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext {
  // Single column of images named "image"
  private lazy val imagePath = "../data/mllib/images"

  test("Smoke test: create basic ImageSchema dataframe") {
    val origin = "path"
    val width = 1
    val height = 1
    val nChannels = 3
    val data = Array[Byte](0, 0, 0)
    val mode = ocvTypes("CV_8UC3")

    // Internal Row corresponds to image StructType
    val rows = Seq(Row(Row(origin, height, width, nChannels, mode, data)),
      Row(Row(null, height, width, nChannels, mode, data)))
    val rdd = sc.makeRDD(rows)
    val df = spark.createDataFrame(rdd, ImageSchema.imageSchema)

    assert(df.count === 2, "incorrect image count")
    assert(df.schema("image").dataType == columnSchema, "data do not fit ImageSchema")
  }

  test("readImages count test") {
    var df = readImages(imagePath)
    assert(df.count === 1)

    df = readImages(imagePath, null, true, -1, false, 1.0, 0)
    assert(df.count === 10)

    df = readImages(imagePath, null, true, -1, true, 1.0, 0)
    val countTotal = df.count
    assert(countTotal === 8)

    df = readImages(imagePath, null, true, -1, true, 0.5, 0)
    // Random number about half of the size of the original dataset
    val count50 = df.count
    assert(count50 > 0 && count50 < countTotal)
  }

  test("readImages partition test") {
    val df = readImages(imagePath, null, true, 3, true, 1.0, 0)
    assert(df.rdd.getNumPartitions === 3)
  }

  // Images with the different number of channels
  test("readImages pixel values test") {

    val images = readImages(imagePath + "/multi-channel/").collect

    images.foreach { rrow =>
      val row = rrow.getAs[Row](0)
      val filename = Paths.get(getOrigin(row)).getFileName().toString()
      if (firstBytes20.contains(filename)) {
        val mode = getMode(row)
        val bytes20 = getData(row).slice(0, 20)

        val (expectedMode, expectedBytes) = firstBytes20(filename)
        assert(ocvTypes(expectedMode) === mode, "mode of the image is not read correctly")
        assert(Arrays.equals(expectedBytes, bytes20), "incorrect numeric value for flattened image")
      }
    }
  }

  // number of channels and first 20 bytes of OpenCV representation
  // - default representation for 3-channel RGB images is BGR row-wise:
  //   (B00, G00, R00,      B10, G10, R10,      ...)
  // - default representation for 4-channel RGB images is BGRA row-wise:
  //   (B00, G00, R00, A00, B10, G10, R10, A10, ...)
  private val firstBytes20 = Map(
    "grayscale.jpg" ->
      (("CV_8UC1", Array[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62,
        -57, -60, -63, -53, -49, -55, -69))),
    "chr30.4.184.jpg" -> (("CV_8UC3",
      Array[Byte](-9, -3, -1, -43, -32, -28, -75, -60, -57, -78, -59, -56, -74, -59, -57,
        -71, -58, -56, -73, -64))),
    "BGRA.png" -> (("CV_8UC4",
      Array[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128,
        -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))),
    "BGRA_alpha_60.png" -> (("CV_8UC4",
      Array[Byte](-128, -128, -8, 60, -128, -128, -8, 60, -128,
        -128, -8, 60, 127, 127, -9, 60, 127, 127, -9, 60)))
  )
} 
Example 22
Source File: KafkaContinuousSourceSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.kafka010

import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

import org.scalatest.time.SpanSugar._
import scala.collection.mutable
import scala.util.Random

import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}

// Run tests in KafkaSourceSuiteBase in continuous execution mode.
class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest

class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
  import testImplicits._

  override val brokerProps = Map("auto.create.topics.enable" -> "false")

  test("subscribing topic by pattern with topic deletions") {
    val topicPrefix = newTopic()
    val topic = topicPrefix + "-seems"
    val topic2 = topicPrefix + "-bad"
    testUtils.createTopic(topic, partitions = 5)
    testUtils.sendMessages(topic, Array("-1"))
    require(testUtils.getLatestOffsets(Set(topic)).size === 5)

    val reader = spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
      .option("kafka.metadata.max.age.ms", "1")
      .option("subscribePattern", s"$topicPrefix-.*")
      .option("failOnDataLoss", "false")

    val kafka = reader.load()
      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
      .as[(String, String)]
    val mapped = kafka.map(kv => kv._2.toInt + 1)

    testStream(mapped)(
      makeSureGetOffsetCalled,
      AddKafkaData(Set(topic), 1, 2, 3),
      CheckAnswer(2, 3, 4),
      Execute { query =>
        testUtils.deleteTopic(topic)
        testUtils.createTopic(topic2, partitions = 5)
        eventually(timeout(streamingTimeout)) {
          assert(
            query.lastExecution.logical.collectFirst {
              case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
            }.exists { r =>
              // Ensure the new topic is present and the old topic is gone.
              r.knownPartitions.exists(_.topic == topic2)
            },
            s"query never reconfigured to new topic $topic2")
        }
      },
      AddKafkaData(Set(topic2), 4, 5, 6),
      CheckAnswer(2, 3, 4, 5, 6, 7)
    )
  }
}

class KafkaContinuousSourceStressForDontFailOnDataLossSuite
    extends KafkaSourceStressForDontFailOnDataLossSuite {
  override protected def startStream(ds: Dataset[Int]) = {
    ds.writeStream
      .format("memory")
      .queryName("memory")
      .trigger(Trigger.Continuous("1 second"))
      .start()
  }
} 
Example 23
Source File: InsertIntoHiveDirCommand.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.hive.execution

import scala.language.existentials

import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.common.FileUtils
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.hadoop.mapred._

import org.apache.spark.SparkException
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.hive.client.HiveClientImpl


case class InsertIntoHiveDirCommand(
    isLocal: Boolean,
    storage: CatalogStorageFormat,
    query: LogicalPlan,
    overwrite: Boolean,
    outputColumns: Seq[Attribute]) extends SaveAsHiveFile {

  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
    assert(storage.locationUri.nonEmpty)

    val hiveTable = HiveClientImpl.toHiveTable(CatalogTable(
      identifier = TableIdentifier(storage.locationUri.get.toString, Some("default")),
      tableType = org.apache.spark.sql.catalyst.catalog.CatalogTableType.VIEW,
      storage = storage,
      schema = query.schema
    ))
    hiveTable.getMetadata.put(serdeConstants.SERIALIZATION_LIB,
      storage.serde.getOrElse(classOf[LazySimpleSerDe].getName))

    val tableDesc = new TableDesc(
      hiveTable.getInputFormatClass,
      hiveTable.getOutputFormatClass,
      hiveTable.getMetadata
    )

    val hadoopConf = sparkSession.sessionState.newHadoopConf()
    val jobConf = new JobConf(hadoopConf)

    val targetPath = new Path(storage.locationUri.get)
    val writeToPath =
      if (isLocal) {
        val localFileSystem = FileSystem.getLocal(jobConf)
        localFileSystem.makeQualified(targetPath)
      } else {
        val qualifiedPath = FileUtils.makeQualified(targetPath, hadoopConf)
        val dfs = qualifiedPath.getFileSystem(jobConf)
        if (!dfs.exists(qualifiedPath)) {
          dfs.mkdirs(qualifiedPath.getParent)
        }
        qualifiedPath
      }

    val tmpPath = getExternalTmpPath(sparkSession, hadoopConf, writeToPath)
    val fileSinkConf = new org.apache.spark.sql.hive.HiveShim.ShimFileSinkDesc(
      tmpPath.toString, tableDesc, false)

    try {
      saveAsHiveFile(
        sparkSession = sparkSession,
        plan = child,
        hadoopConf = hadoopConf,
        fileSinkConf = fileSinkConf,
        outputLocation = tmpPath.toString,
        allColumns = outputColumns)

      val fs = writeToPath.getFileSystem(hadoopConf)
      if (overwrite && fs.exists(writeToPath)) {
        fs.listStatus(writeToPath).foreach { existFile =>
          if (Option(existFile.getPath) != createdTempDir) fs.delete(existFile.getPath, true)
        }
      }

      fs.listStatus(tmpPath).foreach {
        tmpFile => fs.rename(tmpFile.getPath, writeToPath)
      }
    } catch {
      case e: Throwable =>
        throw new SparkException(
          "Failed inserting overwrite directory " + storage.locationUri.get, e)
    } finally {
      deleteExternalTmpPath(hadoopConf)
    }

    Seq.empty[Row]
  }
} 
Example 24
Source File: CreateHiveTableAsSelectCommand.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.hive.execution

import scala.util.control.NonFatal

import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand



case class CreateHiveTableAsSelectCommand(
    tableDesc: CatalogTable,
    query: LogicalPlan,
    outputColumns: Seq[Attribute],
    mode: SaveMode)
  extends DataWritingCommand {

  private val tableIdentifier = tableDesc.identifier

  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog
    if (catalog.tableExists(tableIdentifier)) {
      assert(mode != SaveMode.Overwrite,
        s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite")

      if (mode == SaveMode.ErrorIfExists) {
        throw new AnalysisException(s"$tableIdentifier already exists.")
      }
      if (mode == SaveMode.Ignore) {
        // Since the table already exists and the save mode is Ignore, we will just return.
        return Seq.empty
      }

      InsertIntoHiveTable(
        tableDesc,
        Map.empty,
        query,
        overwrite = false,
        ifPartitionNotExists = false,
        outputColumns = outputColumns).run(sparkSession, child)
    } else {
      // TODO: ideally, we should get the output data ready first and then add the relation
      // into the catalog, in case a failure occurs while processing the data.
      assert(tableDesc.schema.isEmpty)
      catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false)

      try {
        // Read back the metadata of the table which was created just now.
        val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier)
        // For CTAS, there is no static partition values to insert.
        val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap
        InsertIntoHiveTable(
          createdTableMeta,
          partition,
          query,
          overwrite = true,
          ifPartitionNotExists = false,
          outputColumns = outputColumns).run(sparkSession, child)
      } catch {
        case NonFatal(e) =>
          // drop the created table.
          catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, purge = false)
          throw e
      }
    }

    Seq.empty[Row]
  }

  override def argString: String = {
    s"[Database:${tableDesc.database}}, " +
    s"TableName: ${tableDesc.identifier.table}, " +
    s"InsertIntoHiveTable]"
  }
} 
Example 25
Source File: HiveParquetSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.hive.test.TestHiveSingleton

case class Cases(lower: String, UPPER: String)

class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton {

  test("Case insensitive attribute names") {
    withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") {
      val expected = (1 to 4).map(i => Row(i.toString))
      checkAnswer(sql("SELECT upper FROM cases"), expected)
      checkAnswer(sql("SELECT LOWER FROM cases"), expected)
    }
  }

  test("SELECT on Parquet table") {
    val data = (1 to 4).map(i => (i, s"val_$i"))
    withParquetTable(data, "t") {
      checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple))
    }
  }

  test("Simple column projection + filter on Parquet table") {
    withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") {
      checkAnswer(
        sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"),
        Seq(Row(true, "val_2"), Row(true, "val_4")))
    }
  }

  test("Converting Hive to Parquet Table via saveAsParquetFile") {
    withTempPath { dir =>
      sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath)
      spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("p")
      withTempView("p") {
        checkAnswer(
          sql("SELECT * FROM src ORDER BY key"),
          sql("SELECT * from p ORDER BY key").collect().toSeq)
      }
    }
  }

  test("INSERT OVERWRITE TABLE Parquet table") {
    // Don't run with vectorized: currently relies on UnsafeRow.
    withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t", false) {
      withTempPath { file =>
        sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
        spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("p")
        withTempView("p") {
          // let's do three overwrites for good measure
          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
          checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq)
        }
      }
    }
  }
} 
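The expectations above are built with Row(...) and Row.fromTuple. For quick reference, here is a minimal standalone sketch (not part of the suite above) of the common ways to construct and read Row values; all names are illustrative:

import org.apache.spark.sql.Row

val r1 = Row(1, "val_1")                 // positional constructor
val r2 = Row.fromSeq(Seq(2, "val_2"))    // build from a Seq
val r3 = Row.fromTuple((3, "val_3"))     // build from a tuple, as the suite above does

val key: Int = r1.getInt(0)              // typed accessor
val value: String = r2.getAs[String](1)  // generic accessor
val Row(k: Int, v: String) = r3          // pattern matching via Row.unapplySeq
println((key, value, k, v))              // (1,val_2,3,val_3)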
Example 26
Source File: HiveDataFrameJoinSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHiveSingleton

class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton {
  import spark.implicits._

  // We should move this into SQL package if we make case sensitivity configurable in SQL.
  test("join - self join auto resolve ambiguity with case insensitivity") {
    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
    checkAnswer(
      df.join(df, df("key") === df("Key")),
      Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil)

    checkAnswer(
      df.join(df.filter($"value" === "2"), df("key") === df("Key")),
      Row(2, "2", 2, "2") :: Nil)
  }

} 
Example 27
Source File: ListTablesSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.hive

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.hive.test.TestHiveSingleton

class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
  import hiveContext._
  import hiveContext.implicits._

  val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")

  override def beforeAll(): Unit = {
    super.beforeAll()
    // The catalog in HiveContext is a case insensitive one.
    sessionState.catalog.createTempView(
      "ListTablesSuiteTable", df.logicalPlan, overrideIfExists = true)
    sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)")
    sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB")
    sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)")
  }

  override def afterAll(): Unit = {
    try {
      sessionState.catalog.dropTable(
        TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true, purge = false)
      sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable")
      sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable")
      sql("DROP DATABASE IF EXISTS ListTablesSuiteDB")
    } finally {
      super.afterAll()
    }
  }

  test("get all tables of current database") {
    Seq(tables(), sql("SHOW TABLes")).foreach {
      case allTables =>
        // We are using default DB.
        checkAnswer(
          allTables.filter("tableName = 'listtablessuitetable'"),
          Row("", "listtablessuitetable", true))
        checkAnswer(
          allTables.filter("tableName = 'hivelisttablessuitetable'"),
          Row("default", "hivelisttablessuitetable", false))
        assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0)
    }
  }

  test("getting all tables with a database name") {
    Seq(tables("listtablessuiteDb"), sql("SHOW TABLes in listTablesSuitedb")).foreach {
      case allTables =>
        checkAnswer(
          allTables.filter("tableName = 'listtablessuitetable'"),
          Row("", "listtablessuitetable", true))
        assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0)
        checkAnswer(
          allTables.filter("tableName = 'hiveindblisttablessuitetable'"),
          Row("listtablessuitedb", "hiveindblisttablessuitetable", false))
    }
  }
} 
Example 28
Source File: HiveVariableSubstitutionSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHiveSingleton

class HiveVariableSubstitutionSuite extends QueryTest with TestHiveSingleton {
  test("SET hivevar with prefix") {
    spark.sql("SET hivevar:county=gram")
    assert(spark.conf.getOption("county") === Some("gram"))
  }

  test("SET hivevar with dotted name") {
    spark.sql("SET hivevar:eloquent.mosquito.alphabet=zip")
    assert(spark.conf.getOption("eloquent.mosquito.alphabet") === Some("zip"))
  }

  test("hivevar substitution") {
    spark.conf.set("pond", "bus")
    checkAnswer(spark.sql("SELECT '${hivevar:pond}'"), Row("bus") :: Nil)
  }

  test("variable substitution without a prefix") {
    spark.sql("SET hivevar:flask=plaid")
    checkAnswer(spark.sql("SELECT '${flask}'"), Row("plaid") :: Nil)
  }

  test("variable substitution precedence") {
    spark.conf.set("turn.aloof", "questionable")
    spark.sql("SET hivevar:turn.aloof=dime")
    // hivevar clobbers the conf setting
    checkAnswer(spark.sql("SELECT '${turn.aloof}'"), Row("dime") :: Nil)
  }
} 
Example 29
Source File: JsonHadoopFsRelationSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.sources

import java.math.BigDecimal

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.catalog.CatalogUtils
import org.apache.spark.sql.types._

class JsonHadoopFsRelationSuite extends HadoopFsRelationTest {
  override val dataSourceName: String = "json"

  // JSON does not write data of NullType and does not play well with BinaryType.
  override protected def supportsDataType(dataType: DataType): Boolean = dataType match {
    case _: NullType => false
    case _: BinaryType => false
    case _: CalendarIntervalType => false
    case _ => true
  }

  test("save()/load() - partitioned table - simple queries - partition columns in data") {
    withTempDir { file =>
      for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
        val partitionDir = new Path(
          CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2")
        sparkContext
          .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""")
          .saveAsTextFile(partitionDir.toString)
      }

      val dataSchemaWithPartition =
        StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))

      checkQueries(
        spark.read.format(dataSourceName)
          .option("dataSchema", dataSchemaWithPartition.json)
          .load(file.getCanonicalPath))
    }
  }

  test("SPARK-9894: save complex types to JSON") {
    withTempDir { file =>
      file.delete()

      val schema =
        new StructType()
          .add("array", ArrayType(LongType))
          .add("map", MapType(StringType, new StructType().add("innerField", LongType)))

      val data =
        Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) ::
          Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil
      val df = spark.createDataFrame(sparkContext.parallelize(data), schema)

      // Write the data out.
      df.write.format(dataSourceName).save(file.getCanonicalPath)

      // Read it back and check the result.
      checkAnswer(
        spark.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath),
        df
      )
    }
  }

  test("SPARK-10196: save decimal type to JSON") {
    withTempDir { file =>
      file.delete()

      val schema =
        new StructType()
          .add("decimal", DecimalType(7, 2))

      val data =
        Row(new BigDecimal("10.02")) ::
          Row(new BigDecimal("20000.99")) ::
          Row(new BigDecimal("10000")) :: Nil
      val df = spark.createDataFrame(sparkContext.parallelize(data), schema)

      // Write the data out.
      df.write.format(dataSourceName).save(file.getCanonicalPath)

      // Read it back and check the result.
      checkAnswer(
        spark.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath),
        df
      )
    }
  }
} 
Example 30
Source File: LocalRelation.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.types.{StructField, StructType}

object LocalRelation {
  def apply(output: Attribute*): LocalRelation = new LocalRelation(output)

  def apply(output1: StructField, output: StructField*): LocalRelation = {
    new LocalRelation(StructType(output1 +: output).toAttributes)
  }

  def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = {
    val schema = StructType.fromAttributes(output)
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
    LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow]))
  }

  def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = {
    val schema = StructType.fromAttributes(output)
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
    LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow]))
  }
}

case class LocalRelation(
    output: Seq[Attribute],
    data: Seq[InternalRow] = Nil,
    // Indicates whether this relation has data from a streaming source.
    override val isStreaming: Boolean = false)
  extends LeafNode with analysis.MultiInstanceRelation {

  // A local relation must have resolved output.
  require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.")

  
  override final def newInstance(): this.type = {
    LocalRelation(output.map(_.newInstance()), data, isStreaming).asInstanceOf[this.type]
  }

  override protected def stringArgs: Iterator[Any] = {
    if (data.isEmpty) {
      Iterator("<empty>", output)
    } else {
      Iterator(output)
    }
  }

  override def computeStats(): Statistics =
    Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)

  def toSQL(inlineTableName: String): String = {
    require(data.nonEmpty)
    val types = output.map(_.dataType)
    val rows = data.map { row =>
      val cells = row.toSeq(types).zip(types).map { case (v, tpe) => Literal(v, tpe).sql }
      cells.mkString("(", ", ", ")")
    }
    "VALUES " + rows.mkString(", ") +
      " AS " + inlineTableName +
      output.map(_.name).mkString("(", ", ", ")")
  }
} 
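LocalRelation.fromExternalRows is the hook that other examples on this page (for instance FrequentItems below) use to turn plain Row values into a query plan. A minimal sketch of calling it directly, assuming spark-catalyst 2.3.x on the classpath; the object name is illustrative:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object LocalRelationSketch extends App {
  val schema = StructType(Seq(
    StructField("id", IntegerType, nullable = false),
    StructField("name", StringType)))

  // toAttributes turns the external schema into resolved Catalyst attributes,
  // and fromExternalRows converts each Row into an InternalRow.
  val relation = LocalRelation.fromExternalRows(
    schema.toAttributes,
    Seq(Row(1, "a"), Row(2, "b")))

  println(relation.output.map(_.name))  // List(id, name)
  println(relation.data.length)         // 2
}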
Example 31
Source File: CatalystTypeConvertersSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

class CatalystTypeConvertersSuite extends SparkFunSuite {

  private val simpleTypes: Seq[DataType] = Seq(
    StringType,
    DateType,
    BooleanType,
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType.SYSTEM_DEFAULT,
    DecimalType.USER_DEFAULT)

  test("null handling in rows") {
    val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t)))
    val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
    val convertToScala = CatalystTypeConverters.createToScalaConverter(schema)

    val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null))
    assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow)
  }

  test("null handling for individual values") {
    for (dataType <- simpleTypes) {
      assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null)
    }
  }

  test("option handling in convertToCatalyst") {
    // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with
    // createToCatalystConverter but it may not actually matter as this is only called internally
    // in a handful of places where we don't expect to receive Options.
    assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123))
  }

  test("option handling in createToCatalystConverter") {
    assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
  }

  test("primitive array handling") {
    val intArray = Array(1, 100, 10000)
    val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray)
    val intArrayType = ArrayType(IntegerType, false)
    assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) === intArray)

    val doubleArray = Array(1.1, 111.1, 11111.1)
    val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray)
    val doubleArrayType = ArrayType(DoubleType, false)
    assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleUnsafeArray)
      === doubleArray)
  }

  test("An array with null handling") {
    val intArray = Array(1, null, 100, null, 10000)
    val intGenericArray = new GenericArrayData(intArray)
    val intArrayType = ArrayType(IntegerType, true)
    assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intGenericArray)
      === intArray)
    assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray)
      == intGenericArray)

    val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
    val doubleGenericArray = new GenericArrayData(doubleArray)
    val doubleArrayType = ArrayType(DoubleType, true)
    assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleGenericArray)
      === doubleArray)
    assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray)
      == doubleGenericArray)
  }
} 
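The null-handling test above round-trips a whole Row through the two converter factories. A minimal sketch of that round trip on its own, with illustrative field names:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(StructField("i", IntegerType), StructField("s", StringType)))
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
val toScala = CatalystTypeConverters.createToScalaConverter(schema)

val external = Row(1, "a")
val internal = toCatalyst(external)      // Catalyst InternalRow (strings become UTF8String)
println(toScala(internal) == external)   // true: the round trip preserves the values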
Example 32
Source File: DataSourceV2ScanExec.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.v2

import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
import org.apache.spark.sql.types.StructType


case class DataSourceV2ScanExec(
    output: Seq[AttributeReference],
    @transient reader: DataSourceReader)
  extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {

  override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]

  override def outputPartitioning: physical.Partitioning = reader match {
    case s: SupportsReportPartitioning =>
      new DataSourcePartitioning(
        s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))

    case _ => super.outputPartitioning
  }

  private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match {
    case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories()
    case _ =>
      reader.createDataReaderFactories().asScala.map {
        new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
      }.asJava
  }

  private lazy val inputRDD: RDD[InternalRow] = reader match {
    case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
      assert(!reader.isInstanceOf[ContinuousReader],
        "continuous stream reader does not support columnar read yet.")
      new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories())
        .asInstanceOf[RDD[InternalRow]]

    case _: ContinuousReader =>
      EpochCoordinatorRef.get(
          sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
          sparkContext.env)
        .askSync[Unit](SetReaderPartitions(readerFactories.size()))
      new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
        .asInstanceOf[RDD[InternalRow]]

    case _ =>
      new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]]
  }

  override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD)

  override val supportsBatch: Boolean = reader match {
    case r: SupportsScanColumnarBatch if r.enableBatchRead() => true
    case _ => false
  }

  override protected def needsUnsafeRowConversion: Boolean = false

  override protected def doExecute(): RDD[InternalRow] = {
    if (supportsBatch) {
      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
    } else {
      val numOutputRows = longMetric("numOutputRows")
      inputRDD.map { r =>
        numOutputRows += 1
        r
      }
    }
  }
}

class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType)
  extends DataReaderFactory[UnsafeRow] {

  override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations

  override def createDataReader: DataReader[UnsafeRow] = {
    new RowToUnsafeDataReader(
      rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind())
  }
}

class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row])
  extends DataReader[UnsafeRow] {

  override def next: Boolean = rowReader.next

  override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow]

  override def close(): Unit = rowReader.close()
} 
Example 33
Source File: ParquetOutputWriter.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
import org.apache.parquet.hadoop.ParquetOutputFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.OutputWriter

// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptContext)
  extends OutputWriter {

  private val recordWriter: RecordWriter[Void, InternalRow] = {
    new ParquetOutputFormat[InternalRow]() {
      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
        new Path(path)
      }
    }.getRecordWriter(context)
  }

  override def write(row: InternalRow): Unit = recordWriter.write(null, row)

  override def close(): Unit = recordWriter.close(context)
} 
Example 34
Source File: SaveIntoDataSourceCommand.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.CreatableRelationProvider


case class SaveIntoDataSourceCommand(
    query: LogicalPlan,
    dataSource: CreatableRelationProvider,
    options: Map[String, String],
    mode: SaveMode) extends RunnableCommand {

  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

  override def run(sparkSession: SparkSession): Seq[Row] = {
    dataSource.createRelation(
      sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query))

    Seq.empty[Row]
  }

  override def simpleString: String = {
    val redacted = SQLConf.get.redactOptions(options)
    s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}"
  }
} 
Example 35
Source File: MapPartitionsRWrapper.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.r

import org.apache.spark.api.r._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.Row
import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.types.StructType


case class MapPartitionsRWrapper(
    func: Array[Byte],
    packageNames: Array[Byte],
    broadcastVars: Array[Broadcast[Object]],
    inputSchema: StructType,
    outputSchema: StructType) extends (Iterator[Any] => Iterator[Any]) {
  def apply(iter: Iterator[Any]): Iterator[Any] = {
    // Check whether the content of the current DataFrame is serialized R data.
    val isSerializedRData = inputSchema == SERIALIZED_R_DATA_SCHEMA

    val (newIter, deserializer, colNames) =
      if (!isSerializedRData) {
        // Serialize each row into a byte array that can be deserialized in the R worker
        (iter.asInstanceOf[Iterator[Row]].map {row => rowToRBytes(row)},
         SerializationFormats.ROW, inputSchema.fieldNames)
      } else {
        (iter.asInstanceOf[Iterator[Row]].map { row => row(0) }, SerializationFormats.BYTE, null)
      }

    val serializer = if (outputSchema != SERIALIZED_R_DATA_SCHEMA) {
      SerializationFormats.ROW
    } else {
      SerializationFormats.BYTE
    }

    val runner = new RRunner[Array[Byte]](
      func, deserializer, serializer, packageNames, broadcastVars,
      isDataFrame = true, colNames = colNames, mode = RRunnerModes.DATAFRAME_DAPPLY)
    // Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.
    val outputIter = runner.compute(newIter, -1)

    if (serializer == SerializationFormats.ROW) {
      outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
    } else {
      outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
    }
  }
} 
Example 36
Source File: FrequentItems.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

object FrequentItems extends Logging {

  
  def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
    val numCols = cols.length
    // number of max items to keep counts for
    val sizeOfMap = (1 / support).toInt
    val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
    val originalSchema = df.schema
    val colInfo: Array[(String, DataType)] = cols.map { name =>
      val index = originalSchema.fieldIndex(name)
      (name, originalSchema.fields(index).dataType)
    }.toArray

    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)(
      seqOp = (counts, row) => {
        var i = 0
        while (i < numCols) {
          val thisMap = counts(i)
          val key = row.get(i)
          thisMap.add(key, 1L)
          i += 1
        }
        counts
      },
      combOp = (baseCounts, counts) => {
        var i = 0
        while (i < numCols) {
          baseCounts(i).merge(counts(i))
          i += 1
        }
        baseCounts
      }
    )
    val justItems = freqItems.map(m => m.baseMap.keys.toArray)
    val resultRow = Row(justItems : _*)
    // Append "_freqItems" to each column name to make the output easier to inspect.
    val outputCols = colInfo.map { v =>
      StructField(v._1 + "_freqItems", ArrayType(v._2, false))
    }
    val schema = StructType(outputCols).toAttributes
    Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
  }
} 
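The same computation is exposed through the public DataFrame API, which delegates to singlePassFreqItems above and returns a single Row of arrays. A minimal sketch, assuming an active SparkSession named spark:

import spark.implicits._

val df = Seq((1, "a"), (1, "b"), (2, "a"), (1, "a")).toDF("k", "v")
val freq = df.stat.freqItems(Seq("k", "v"), support = 0.4)
freq.show()  // one Row with columns k_freqItems and v_freqItems holding the frequent values
val firstColumnItems = freq.head().getSeq[Any](0)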
Example 37
Source File: cache.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

case class CacheTableCommand(
    tableIdent: TableIdentifier,
    plan: Option[LogicalPlan],
    isLazy: Boolean) extends RunnableCommand {
  require(plan.isEmpty || tableIdent.database.isEmpty,
    "Database name is not allowed in CACHE TABLE AS SELECT")

  override protected def innerChildren: Seq[QueryPlan[_]] = plan.toSeq

  override def run(sparkSession: SparkSession): Seq[Row] = {
    plan.foreach { logicalPlan =>
      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
    }
    sparkSession.catalog.cacheTable(tableIdent.quotedString)

    if (!isLazy) {
      // Performs eager caching
      sparkSession.table(tableIdent).count()
    }

    Seq.empty[Row]
  }
}


case class UncacheTableCommand(
    tableIdent: TableIdentifier,
    ifExists: Boolean) extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val tableId = tableIdent.quotedString
    if (!ifExists || sparkSession.catalog.tableExists(tableId)) {
      sparkSession.catalog.uncacheTable(tableId)
    }
    Seq.empty[Row]
  }
}


case class ClearCacheCommand() extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    sparkSession.catalog.clearCache()
    Seq.empty[Row]
  }

  // TreeNode.makeCopy does not support a zero-arg constructor, so rebuild the command explicitly.
  override def makeCopy(newArgs: Array[AnyRef]): ClearCacheCommand = ClearCacheCommand()
} 
Example 38
Source File: resources.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
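ListJarsCommand backs the LIST JARS SQL statement, so its output reaches users as ordinary Rows with a single string column named Results. A minimal sketch, assuming an active SparkSession named spark and that at least one jar has been registered with ADD JAR:

spark.sql("LIST JARS").collect().foreach { row =>
  println(row.getString(0))  // the path of one registered jar
}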
Example 39
Source File: AnalyzeTableCommand.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTableType



case class AnalyzeTableCommand(
    tableIdent: TableIdentifier,
    noscan: Boolean = true) extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val sessionState = sparkSession.sessionState
    val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase)
    val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db))
    val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB)
    if (tableMeta.tableType == CatalogTableType.VIEW) {
      throw new AnalysisException("ANALYZE TABLE is not supported on views.")
    }

    // Compute stats for the whole table
    val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta)
    val newRowCount =
      if (noscan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count()))

    // Update the metastore if the above statistics of the table are different from those
    // recorded in the metastore.
    val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount)
    if (newStats.isDefined) {
      sessionState.catalog.alterTableStats(tableIdentWithDB, newStats)
    }

    Seq.empty[Row]
  }
} 
Example 40
Source File: MicroBatchWriter.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter


class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWriter {
  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    writer.commit(batchId, messages)
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)

  override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory()
}

class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter)
  extends DataSourceWriter with SupportsWriteInternalRow {
  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    writer.commit(batchId, messages)
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)

  override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] =
    writer match {
      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
      case _ => throw new IllegalStateException(
        "InternalRowMicroBatchWriter should only be created with base writer support")
    }
} 
Example 41
Source File: ConsoleWriter.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.streaming.sources

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.types.StructType


class ConsoleWriter(schema: StructType, options: DataSourceOptions)
    extends StreamWriter with Logging {

  // Number of rows to display, by default 20 rows
  protected val numRowsToShow = options.getInt("numRows", 20)

  // Truncate the displayed data if it is too long, by default it is true
  protected val isTruncated = options.getBoolean("truncate", true)

  assert(SparkSession.getActiveSession.isDefined)
  protected val spark = SparkSession.getActiveSession.get

  def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
    // behavior.
    printRows(messages, schema, s"Batch: $epochId")
  }

  def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

  protected def printRows(
      commitMessages: Array[WriterCommitMessage],
      schema: StructType,
      printMessage: String): Unit = {
    val rows = commitMessages.collect {
      case PackedRowCommitMessage(rs) => rs
    }.flatten

    // scalastyle:off println
    println("-------------------------------------------")
    println(printMessage)
    println("-------------------------------------------")
    // scalastyle:on println
    spark
      .createDataFrame(rows.toList.asJava, schema)
      .show(numRowsToShow, isTruncated)
  }

  override def toString(): String = {
    s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
  }
} 
Example 42
Source File: RowDataSourceStrategySuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE OR REPLACE TEMPORARY VIEW inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 43
Source File: FileFormatWriterSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

class FileFormatWriterSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("empty file should be skipped while write to file") {
    withTempPath { path =>
      spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
      val partFiles = path.listFiles()
        .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
      assert(partFiles.length === 2)
    }
  }

  test("SPARK-22252: FileFormatWriter should respect the input query schema") {
    withTable("t1", "t2", "t3", "t4") {
      spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1")
      spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2")
      checkAnswer(spark.table("t2"), Row(0, 0))

      // Test picking part of the columns when writing.
      spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3")
      spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4")
      checkAnswer(spark.table("t4"), Row(0, 0))
    }
  }
} 
Example 44
Source File: ParquetProtobufCompatibilitySuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("unannotated array of primitive type") {
    checkAnswer(readResourceParquetFile("test-data/old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
  }

  test("unannotated array of struct") {
    checkAnswer(
      readResourceParquetFile("test-data/old-repeated-message.parquet"),
      Row(
        Seq(
          Row("First inner", null, null),
          Row(null, "Second inner", null),
          Row(null, null, "Third inner"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-struct.parquet"),
      Row(
        Seq(
          Row("0 - 1", "0 - 2", "0 - 3"),
          Row("1 - 1", "1 - 2", "1 - 3"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array-many.parquet"),
      Seq(
        Row(
          Seq(
            Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"),
            Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))),
        Row(
          Seq(
            Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"),
            Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))),
        Row(
          Seq(
            Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"),
            Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3")))))
  }

  test("struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array.parquet"),
      Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
  }

  test("unannotated array of struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/nested-array-struct.parquet"),
      Seq(
        Row(2, Seq(Row(1, Seq(Row(3))))),
        Row(5, Seq(Row(4, Seq(Row(6))))),
        Row(8, Seq(Row(7, Seq(Row(9)))))))
  }

  test("unannotated array of string") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-string.parquet"),
      Seq(
        Row(Seq("hello", "world")),
        Row(Seq("good", "bye")),
        Row(Seq("one", "two", "three"))))
  }
} 
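The expected values above nest Rows inside Rows and sequences. A minimal standalone sketch of reading such a nested Row back field by field, shaped like the proto-struct-with-array expectation; the variable names are illustrative:

import org.apache.spark.sql.Row

val nested = Row(10, 9, Seq.empty[Int], null, Row(9), Seq(Row(9), Row(10)))
val inner: Row = nested.getStruct(4)            // Row(9)
val repeated: Seq[Row] = nested.getSeq[Row](5)  // Seq(Row(9), Row(10))
println(inner.getInt(0) + repeated.map(_.getInt(0)).sum)  // 9 + 9 + 10 = 28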
Example 45
Source File: TakeOrderedAndProjectSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {

  private var rand: Random = _
  private var seed: Long = 0

  protected override def beforeAll(): Unit = {
    super.beforeAll()
    seed = System.currentTimeMillis()
    rand = new Random(seed)
  }

  private def generateRandomInputData(): DataFrame = {
    val schema = new StructType()
      .add("a", IntegerType, nullable = false)
      .add("b", IntegerType, nullable = false)
    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
    spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
  }

  
  private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan)

  val limit = 250
  val sortOrder = 'a.desc :: 'b.desc :: Nil

  test("TakeOrderedAndProject.doExecute without project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              SortExec(sortOrder, true, input))),
        sortAnswers = false)
    }
  }

  test("TakeOrderedAndProject.doExecute with project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(
            TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              ProjectExec(Seq(input.output.last),
                SortExec(sortOrder, true, input)))),
        sortAnswers = false)
    }
  }
} 
Example 46
Source File: GroupedIteratorSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
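The suite above uses a resolved RowEncoder to move between external Rows and InternalRows. A minimal sketch of that conversion by itself, against the Spark 2.3 encoder API where toRow and fromRow are available:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val schema = new StructType().add("i", IntegerType).add("s", StringType)
val encoder = RowEncoder(schema).resolveAndBind()

val internal = encoder.toRow(Row(1, "a"))  // external Row -> InternalRow
val external = encoder.fromRow(internal)   // InternalRow -> external Row
println((external.getInt(0), external.getString(1)))  // (1,a)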
Example 47
Source File: MemorySinkV2Suite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.sources._
import org.apache.spark.sql.streaming.{OutputMode, StreamTest}

class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
  test("data writer") {
    val partition = 1234
    val writer = new MemoryDataWriter(partition, OutputMode.Append())
    writer.write(Row(1))
    writer.write(Row(2))
    writer.write(Row(44))
    val msg = writer.commit()
    assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44))
    assert(msg.partition == partition)

    // Buffer should be cleared, so repeated commits should give empty.
    assert(writer.commit().data.isEmpty)
  }

  test("continuous writer") {
    val sink = new MemorySinkV2
    val writer = new MemoryStreamWriter(sink, OutputMode.Append())
    writer.commit(0,
      Array(
        MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
        MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
        MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))
      ))
    assert(sink.latestBatchId.contains(0))
    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
    writer.commit(19,
      Array(
        MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
        MemoryWriterCommitMessage(0, Seq(Row(33)))
      ))
    assert(sink.latestBatchId.contains(19))
    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))

    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
  }

  test("microbatch writer") {
    val sink = new MemorySinkV2
    new MemoryWriter(sink, 0, OutputMode.Append()).commit(
      Array(
        MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
        MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
        MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))
      ))
    assert(sink.latestBatchId.contains(0))
    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
    new MemoryWriter(sink, 19, OutputMode.Append()).commit(
      Array(
        MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
        MemoryWriterCommitMessage(0, Seq(Row(33)))
      ))
    assert(sink.latestBatchId.contains(19))
    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))

    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
  }
} 
Example 48
Source File: SpreadsheetRelation.scala    From mimir   with Apache License 2.0
package mimir.exec.spark.datasource.google.spreadsheet

import mimir.exec.spark.datasource.google.spreadsheet.SparkSpreadsheetService.SparkSpreadsheetContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

case class SpreadsheetRelation protected[spark] (
                                                  context:SparkSpreadsheetContext,
                                                  spreadsheetName: String,
                                                  worksheetName: String,
                                                  userSchema: Option[StructType] = None)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with InsertableRelation {

  import mimir.exec.spark.datasource.google.spreadsheet.SparkSpreadsheetService._

  private val fieldMap = scala.collection.mutable.Map[String, String]()
  override def schema: StructType = userSchema.getOrElse(inferSchema())

  private lazy val aWorksheet: SparkWorksheet =
    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(aWorksheet) => aWorksheet
      case Left(e) => throw e
    }

  private lazy val rows: Seq[Map[String, String]] = aWorksheet.rows

  private[spreadsheet] def findWorksheet(spreadsheetName: String, worksheetName: String)(implicit ctx: SparkSpreadsheetContext): Either[Throwable, SparkWorksheet] =
    for {
      sheet <- findSpreadsheet(spreadsheetName).toRight(new RuntimeException(s"no such spreadsheet: $spreadsheetName")).right
      worksheet <- sheet.findWorksheet(worksheetName).toRight(new RuntimeException(s"no such worksheet: $worksheetName")).right
    } yield worksheet

  override def buildScan(): RDD[Row] = {
    val aSchema = schema
    val schemaMap = fieldMap.toMap
    sqlContext.sparkContext.makeRDD(rows).mapPartitions { iter =>
      iter.map { m =>
        var index = 0
        val rowArray = new Array[Any](aSchema.fields.length)
        while(index < aSchema.fields.length) {
          val field = aSchema.fields(index)
          rowArray(index) = if (m.contains(field.name)) {
            TypeCast.castTo(m(field.name), field.dataType, field.nullable)
          } else if (schemaMap.contains(field.name) && m.contains(schemaMap(field.name))) {
            TypeCast.castTo(m(schemaMap(field.name)), field.dataType, field.nullable)
          } else {
            null
          }
          index += 1
        }
        Row.fromSeq(rowArray)
      }
    }
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    if(!overwrite) {
      sys.error("Spreadsheet tables only support INSERT OVERWRITE for now.")
    }

    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(w) =>
        w.updateCells(data.schema, data.collect().toList, Util.toRowData)
      case Left(e) =>
        throw e
    }
  }

  def sanitizeColumnName(name: String): String =
  {
    name
      .replaceAll("[^a-zA-Z0-9]+", "_")    // Replace sequences of non-alphanumeric characters with underscores
      .replaceAll("_+$", "")               // Strip trailing underscores
      .replaceAll("^[0-9_]+", "")          // Strip leading underscores and digits
  }

  private def inferSchema(): StructType =
    StructType(aWorksheet.headers.toList.map { fieldName => {
      val sanitizedName = sanitizeColumnName(fieldName)
      fieldMap.put(sanitizedName, fieldName)
      StructField(sanitizedName, StringType, true)
    }})

} 
Example 49
Source File: Util.scala    From mimir   with Apache License 2.0
package mimir.exec.spark.datasource.google.spreadsheet

import com.google.api.services.sheets.v4.model.{ExtendedValue, CellData, RowData}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DataTypes, StructType}

import scala.collection.JavaConverters._

object Util {
  def convert(schema: StructType, row: Row): Map[String, Object] =
    schema.iterator.zipWithIndex.map { case (f, i) => f.name -> row(i).asInstanceOf[AnyRef]}.toMap

  def toRowData(row: Row): RowData =
      new RowData().setValues(
        row.schema.fields.zipWithIndex.map { case (f, i) =>
          new CellData()
            .setUserEnteredValue(
              f.dataType match {
                case DataTypes.StringType => new ExtendedValue().setStringValue(row.getString(i))
                case DataTypes.LongType => new ExtendedValue().setNumberValue(row.getLong(i).toDouble)
                case DataTypes.IntegerType => new ExtendedValue().setNumberValue(row.getInt(i).toDouble)
                case DataTypes.FloatType => new ExtendedValue().setNumberValue(row.getFloat(i).toDouble)
                case DataTypes.BooleanType => new ExtendedValue().setBoolValue(row.getBoolean(i))
                case DataTypes.DateType => new ExtendedValue().setStringValue(row.getDate(i).toString)
                case DataTypes.ShortType => new ExtendedValue().setNumberValue(row.getShort(i).toDouble)
                case DataTypes.TimestampType => new ExtendedValue().setStringValue(row.getTimestamp(i).toString)
                case DataTypes.DoubleType => new ExtendedValue().setNumberValue(row.getDouble(i))
              }
            )
        }.toList.asJava
      )

} 
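toRowData assumes the incoming Row carries a schema, which is true for Rows collected from a DataFrame. A minimal sketch of feeding it such a Row, assuming an active SparkSession named spark and the Util object above on the classpath:

val row = spark.range(1).selectExpr("CAST(id AS INT) AS id", "'x' AS name").head()
val rowData = Util.toRowData(row)  // one CellData per field: a number for id, a string for name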
Example 50
Source File: SageMakerProtobufWriter.scala    From sagemaker-spark   with Apache License 2.0
package com.amazonaws.services.sagemaker.sparksdk.protobuf

import java.io.ByteArrayOutputStream

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{BytesWritable, NullWritable}
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types.StructType


// NOTE: the enclosing class declaration and its earlier members were dropped when this
// snippet was extracted; the header below is a reconstruction so that the fields used
// further down (options, byteArrayOutputStream, recordWriter, context) are defined.
class SageMakerProtobufWriter(
    path: String,
    context: TaskAttemptContext,
    schema: StructType,
    options: Map[String, String]) extends OutputWriter {

  private val byteArrayOutputStream = new ByteArrayOutputStream()

  // Converts Catalyst InternalRows back to external Rows before serialization.
  private val converter =
    CatalystTypeConverters.createToScalaConverter(schema).asInstanceOf[InternalRow => Row]

  // Hadoop record writer that appends recordIO-encoded records to a single file at `path`.
  private val recordWriter: RecordWriter[NullWritable, BytesWritable] =
    new RecordIOOutputFormat() {
      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path =
        new Path(path)
    }.getRecordWriter(context)

  override def write(row: InternalRow): Unit = write(converter(row))

  def write(row: Row): Unit = {
    val labelColumnName = options.getOrElse("labelColumnName", "label")
    val featuresColumnName = options.getOrElse("featuresColumnName", "features")

    val record = ProtobufConverter.rowToProtobuf(row, featuresColumnName, Some(labelColumnName))
    record.writeTo(byteArrayOutputStream)

    recordWriter.write(NullWritable.get(), new BytesWritable(byteArrayOutputStream.toByteArray))
    byteArrayOutputStream.reset()
  }

  override def close(): Unit = {
    recordWriter.close(context)
  }
} 
Example 51
Source File: LibSVMResponseRowDeserializer.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers

import org.apache.spark.ml.linalg.{SparseVector, SQLDataTypes}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

import com.amazonaws.services.sagemaker.sparksdk.transformation.{ContentTypes, ResponseRowDeserializer}


  override val accepts: String = ContentTypes.TEXT_LIBSVM

  private def parseLibSVMRow(record: String): Row = {
    val items = record.split(' ')
    val label = items.head.toDouble
    val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
      val entry = item.split(':')
      val index = entry(0).toInt - 1
      val value = entry(1).toDouble
      (index, value)
    }.unzip
    Row(label, new SparseVector(dim, indices.toArray, values.toArray))
  }

  override val schema: StructType = StructType(
    Array(
      StructField(labelColumnName, DoubleType, nullable = false),
      StructField(featuresColumnName, SQLDataTypes.VectorType, nullable = false)))
} 
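
parseLibSVMRow follows the LIBSVM text format, label index:value ..., where indices are 1-based. A standalone sketch of the same parsing logic (dim is a constructor argument elided from the snippet above, so a hypothetical dim = 4 is hard-coded here):

import org.apache.spark.ml.linalg.SparseVector
import org.apache.spark.sql.Row

object LibSVMParseSketch {
  val dim = 4  // hypothetical feature dimension; the real deserializer takes this as a parameter

  def parse(record: String): Row = {
    val items = record.split(' ')
    val label = items.head.toDouble
    val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
      val entry = item.split(':')
      (entry(0).toInt - 1, entry(1).toDouble)  // shift 1-based LIBSVM indices to 0-based
    }.unzip
    Row(label, new SparseVector(dim, indices.toArray, values.toArray))
  }

  def main(args: Array[String]): Unit = {
    println(parse("1.0 1:0.5 3:2.0"))  // [1.0,(4,[0,2],[0.5,2.0])]
  }
}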
Example 52
Source File: ROC.scala    From s4ds   with Apache License 2.0 5 votes vote down vote up
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer, StringIndexer}
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

import breeze.linalg._
import breeze.plot._
import org.jfree.chart.axis.NumberTickUnit


object ROC extends App {

  val conf = new SparkConf().setAppName("ROC")
  val sc = new SparkContext(conf)
  val sqlContext = new SQLContext(sc)
  import sqlContext._
  import sqlContext.implicits._

  val transformedTest = sqlContext.read.parquet("transformedTest.parquet")

  val labelScores = transformedTest.select("probability", "label").map {
    case Row(probability: Vector, label: Double) => (probability(1), label)
  }

  val bm = new BinaryClassificationMetrics(labelScores, 300)
  val roc = bm.roc.collect
  
  roc.foreach { println }

  val falsePositives = roc.map { _._1 }
  val truePositives = roc.map { _._2 }

  val f = Figure()
  val p = f.subplot(0)
  p += plot(falsePositives, truePositives)
  p.xlabel = "false positives"
  p.ylabel = "true positives"
  p.xlim = (0.0, 0.1)
  p.xaxis.setTickUnit(new NumberTickUnit(0.01))
  p.yaxis.setTickUnit(new NumberTickUnit(0.1))
  f.refresh
  f.saveas("roc.png")
  

} 
Example 53
Source File: LogisticRegressionDemo.scala    From s4ds   with Apache License 2.0 5 votes vote down vote up
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer, StringIndexer}
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.SaveMode

case class LabelledDocument(fileName: String, text: String, category: String)

object LogisticRegressionDemo extends App {

  val conf = new SparkConf().setAppName("LrTest")
  val sc = new SparkContext(conf)
  val sqlContext = new SQLContext(sc)
  import sqlContext._
  import sqlContext.implicits._

  val spamText = sc.wholeTextFiles("spam/*")
  val hamText = sc.wholeTextFiles("ham/*")

  val spamDocuments = spamText.map { 
    case (fileName, text) => LabelledDocument(fileName, text, "spam")
  }
  val hamDocuments = hamText.map {
    case (fileName, text) => LabelledDocument(fileName, text, "ham")
  }

  val documentsDF = spamDocuments.union(hamDocuments).toDF
  documentsDF.persist

  val Array(trainDF, testDF) = documentsDF.randomSplit(Array(0.7, 0.3))

  val indexer = new StringIndexer().setInputCol("category").setOutputCol("label")
  val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
  val hasher = new HashingTF().setInputCol("words").setOutputCol("features")
  val lr = new LogisticRegression().setMaxIter(50).setRegParam(0.0)

  val pipeline = new Pipeline().setStages(Array(indexer, tokenizer, hasher, lr))
  val model = pipeline.fit(trainDF)

  val transformedTrain = model.transform(trainDF)
  transformedTrain.persist
  
  val transformedTest = model.transform(testDF)
  transformedTest.persist

  println("in sample misclassified:", transformedTrain.filter($"prediction" !== $"label").count,
    " / ",transformedTrain.count)
  println("out sample misclassified:", transformedTest.filter($"prediction" !== $"label").count,
    " / ",transformedTest.count)

  transformedTrain.select("fileName", "label", "prediction", "probability")
    .write.mode(SaveMode.Overwrite).parquet("transformedTrain.parquet")
  transformedTest.select("fileName", "label", "prediction", "probability")
    .write.mode(SaveMode.Overwrite).parquet("transformedTest.parquet")
} 
Example 54
Source File: VectorSlicerExample.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.feature.VectorSlicer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
// $example off$
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object VectorSlicerExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("VectorSlicerExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // $example on$
    val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0)))

    val defaultAttr = NumericAttribute.defaultAttr
    val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName)
    val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]])

    val dataRDD = sc.parallelize(data)
    val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField())))

    val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features")

    slicer.setIndices(Array(1)).setNames(Array("f3"))
    // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3"))

    val output = slicer.transform(dataset)
    println(output.select("userFeatures", "features").first())
    // $example off$
    sc.stop()
  }
}
// scalastyle:on println 
Example 55
Source File: DataFrameExample.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

import java.io.File

import com.google.common.io.Files
import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.sql.{DataFrame, Row, SQLContext}


object DataFrameExample {

  case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
    extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("DataFrameExample") {
      head("DataFrameExample: an example app using DataFrame for ML.")
      opt[String]("input")
        .text(s"input path to dataframe")
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        success
      }
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      sys.exit(1)
    }
  }

  def run(params: Params) {

    val conf = new SparkConf().setAppName(s"DataFrameExample with $params")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Load input data
    println(s"Loading LIBSVM file with UDT from ${params.input}.")
    val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache()
    println("Schema from LIBSVM:")
    df.printSchema()
    println(s"Loaded training data as a DataFrame with ${df.count()} records.")

    // Show statistical summary of labels.
    val labelSummary = df.describe("label")
    labelSummary.show()

    // Convert features column to an RDD of vectors.
    val features = df.select("features").map { case Row(v: Vector) => v }
    val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
      (summary, feat) => summary.add(feat),
      (sum1, sum2) => sum1.merge(sum2))
    println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

    // Save the records in a parquet file.
    val tmpDir = Files.createTempDir()
    tmpDir.deleteOnExit()
    val outputDir = new File(tmpDir, "dataframe").toString
    println(s"Saving to $outputDir as Parquet file.")
    df.write.parquet(outputDir)

    // Load the records back.
    println(s"Loading Parquet file with UDT from $outputDir.")
    val newDF = sqlContext.read.parquet(outputDir)
    println(s"Schema from Parquet:")
    newDF.printSchema()

    sc.stop()
  }
}
// scalastyle:on println 
Example 56
Source File: SimpleTextClassificationPipeline.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

import scala.beans.BeanInfo

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}

@BeanInfo
case class LabeledDocument(id: Long, text: String, label: Double)

@BeanInfo
case class Document(id: Long, text: String)


object SimpleTextClassificationPipeline {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Prepare training documents, which are labeled.
    val training = sc.parallelize(Seq(
      LabeledDocument(0L, "a b c d e spark", 1.0),
      LabeledDocument(1L, "b d", 0.0),
      LabeledDocument(2L, "spark f g h", 1.0),
      LabeledDocument(3L, "hadoop mapreduce", 0.0)))

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.001)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // Fit the pipeline to training documents.
    val model = pipeline.fit(training.toDF())

    // Prepare test documents, which are unlabeled.
    val test = sc.parallelize(Seq(
      Document(4L, "spark i j k"),
      Document(5L, "l m n"),
      Document(6L, "spark hadoop spark"),
      Document(7L, "apache hadoop")))

    // Make predictions on test documents.
    model.transform(test.toDF())
      .select("id", "text", "probability", "prediction")
      .collect()
      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
        println(s"($id, $text) --> prob=$prob, prediction=$prediction")
      }

    sc.stop()
  }
}
// scalastyle:on println 
Example 57
Source File: BigDatalogProgram.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package edu.ucla.cs.wis.bigdatalog.spark

import edu.ucla.cs.wis.bigdatalog.interpreter.OperatorProgram
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.{DataFrame, Row}

class BigDatalogProgram(var bigDatalogContext: BigDatalogContext,
                        plan: LogicalPlan,
                        operatorProgram: OperatorProgram) {

  def toDF(): DataFrame = {
    new DataFrame(bigDatalogContext, plan)
  }
  
  def count(): Long = {
    toDF().count()
  }

  // Evaluates the program and returns its results as an RDD[Row].
  def execute(): RDD[Row] = {
    toDF().rdd
  }

  override def toString(): String = {
    new QueryExecution(bigDatalogContext, plan).toString
  }
} 
Example 58
Source File: SQLTransformer.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.param.{ParamMap, Param}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util._
import org.apache.spark.sql.{SQLContext, DataFrame, Row}
import org.apache.spark.sql.types.StructType


  @Since("1.6.0")
  def getStatement: String = $(statement)

  private val tableIdentifier: String = "__THIS__"

  @Since("1.6.0")
  override def transform(dataset: DataFrame): DataFrame = {
    val tableName = Identifiable.randomUID(uid)
    dataset.registerTempTable(tableName)
    val realStatement = $(statement).replace(tableIdentifier, tableName)
    val outputDF = dataset.sqlContext.sql(realStatement)
    outputDF
  }

  @Since("1.6.0")
  override def transformSchema(schema: StructType): StructType = {
    val sc = SparkContext.getOrCreate()
    val sqlContext = SQLContext.getOrCreate(sc)
    val dummyRDD = sc.parallelize(Seq(Row.empty))
    val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
    dummyDF.registerTempTable(tableIdentifier)
    val outputSchema = sqlContext.sql($(statement)).schema
    outputSchema
  }

  @Since("1.6.0")
  override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
}

@Since("1.6.0")
object SQLTransformer extends DefaultParamsReadable[SQLTransformer] {

  @Since("1.6.0")
  override def load(path: String): SQLTransformer = super.load(path)
} 
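
A brief usage sketch (assuming a sqlContext is in scope; the column names id, v1, v2 are illustrative): the statement refers to the input DataFrame through the __THIS__ placeholder, which transform() replaces with the name of a freshly registered temp table:

val df = sqlContext.createDataFrame(Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
val sqlTrans = new SQLTransformer().setStatement(
  "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
sqlTrans.transform(df).show()  // adds v3 and v4 columns computed by the SQL statement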
Example 59
Source File: BinaryClassificationEvaluator.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType


  @Since("1.2.0")
  def setLabelCol(value: String): this.type = set(labelCol, value)

  setDefault(metricName -> "areaUnderROC")

  @Since("1.2.0")
  override def evaluate(dataset: DataFrame): Double = {
    val schema = dataset.schema
    SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT)
    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)

    // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
    val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol))
      .map { case Row(rawPrediction: Vector, label: Double) =>
        (rawPrediction(1), label)
      }
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    val metric = $(metricName) match {
      case "areaUnderROC" => metrics.areaUnderROC()
      case "areaUnderPR" => metrics.areaUnderPR()
    }
    metrics.unpersist()
    metric
  }

  @Since("1.5.0")
  override def isLargerBetter: Boolean = $(metricName) match {
    case "areaUnderROC" => true
    case "areaUnderPR" => true
  }

  @Since("1.4.1")
  override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
}

@Since("1.6.0")
object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] {

  @Since("1.6.0")
  override def load(path: String): BinaryClassificationEvaluator = super.load(path)
} 
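
For context, a minimal usage sketch (predictions here stands for the output DataFrame of a fitted binary classifier, with label and rawPrediction columns; it is not defined in the snippet above):

val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")  // or "areaUnderPR"
val auc = evaluator.evaluate(predictions)
println(s"Area under ROC = $auc")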
Example 60
Source File: MulticlassClassificationEvaluator.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.types.DoubleType


  @Since("1.5.0")
  def setLabelCol(value: String): this.type = set(labelCol, value)

  setDefault(metricName -> "f1")

  @Since("1.5.0")
  override def evaluate(dataset: DataFrame): Double = {
    val schema = dataset.schema
    SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)

    val predictionAndLabels = dataset.select($(predictionCol), $(labelCol))
      .map { case Row(prediction: Double, label: Double) =>
      (prediction, label)
    }
    val metrics = new MulticlassMetrics(predictionAndLabels)
    val metric = $(metricName) match {
      case "f1" => metrics.weightedFMeasure
      case "precision" => metrics.precision
      case "recall" => metrics.recall
      case "weightedPrecision" => metrics.weightedPrecision
      case "weightedRecall" => metrics.weightedRecall
    }
    metric
  }

  @Since("1.5.0")
  override def isLargerBetter: Boolean = $(metricName) match {
    case "f1" => true
    case "precision" => true
    case "recall" => true
    case "weightedPrecision" => true
    case "weightedRecall" => true
  }

  @Since("1.5.0")
  override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
}

@Since("1.6.0")
object MulticlassClassificationEvaluator
  extends DefaultParamsReadable[MulticlassClassificationEvaluator] {

  @Since("1.6.0")
  override def load(path: String): MulticlassClassificationEvaluator = super.load(path)
} 
Example 61
Source File: RegressionEvaluator.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType}


  @Since("1.4.0")
  def setLabelCol(value: String): this.type = set(labelCol, value)

  setDefault(metricName -> "rmse")

  @Since("1.4.0")
  override def evaluate(dataset: DataFrame): Double = {
    val schema = dataset.schema
    val predictionColName = $(predictionCol)
    val predictionType = schema($(predictionCol)).dataType
    require(predictionType == FloatType || predictionType == DoubleType,
      s"Prediction column $predictionColName must be of type float or double, " +
        s"but not $predictionType")
    val labelColName = $(labelCol)
    val labelType = schema($(labelCol)).dataType
    require(labelType == FloatType || labelType == DoubleType,
      s"Label column $labelColName must be of type float or double, but not $labelType")

    val predictionAndLabels = dataset
      .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
      .map { case Row(prediction: Double, label: Double) =>
        (prediction, label)
      }
    val metrics = new RegressionMetrics(predictionAndLabels)
    val metric = $(metricName) match {
      case "rmse" => metrics.rootMeanSquaredError
      case "mse" => metrics.meanSquaredError
      case "r2" => metrics.r2
      case "mae" => metrics.meanAbsoluteError
    }
    metric
  }

  @Since("1.4.0")
  override def isLargerBetter: Boolean = $(metricName) match {
    case "rmse" => false
    case "mse" => false
    case "r2" => true
    case "mae" => false
  }

  @Since("1.5.0")
  override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
}

@Since("1.6.0")
object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] {

  @Since("1.6.0")
  override def load(path: String): RegressionEvaluator = super.load(path)
} 
Example 62
Source File: LibSVMRelation.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.source.libsvm

import com.google.common.base.Objects

import org.apache.spark.Logging
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}


@Since("1.6.0")
class DefaultSource extends RelationProvider with DataSourceRegister {

  @Since("1.6.0")
  override def shortName(): String = "libsvm"

  @Since("1.6.0")
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
    : BaseRelation = {
    val path = parameters.getOrElse("path",
      throw new IllegalArgumentException("'path' must be specified"))
    val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
    val vectorType = parameters.getOrElse("vectorType", "sparse")
    new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
  }
} 
Example 63
Source File: GLMClassificationModel.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.classification.impl

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
import org.apache.spark.sql.{Row, SQLContext}


    def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
      val datapath = Loader.dataPath(path)
      val sqlContext = SQLContext.getOrCreate(sc)
      val dataRDD = sqlContext.read.parquet(datapath)
      val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
      assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
      val data = dataArray(0)
      assert(data.size == 3, s"Unable to load $modelClass data from: $datapath")
      val (weights, intercept) = data match {
        case Row(weights: Vector, intercept: Double, _) =>
          (weights, intercept)
      }
      val threshold = if (data.isNullAt(2)) {
        None
      } else {
        Some(data.getDouble(2))
      }
      Data(weights, intercept, threshold)
    }
  }

} 
Example 64
Source File: GLMRegressionModel.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.regression.impl

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
import org.apache.spark.sql.{DataFrame, Row, SQLContext}


    def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
      val datapath = Loader.dataPath(path)
      val sqlContext = SQLContext.getOrCreate(sc)
      val dataRDD = sqlContext.read.parquet(datapath)
      val dataArray = dataRDD.select("weights", "intercept").take(1)
      assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
      val data = dataArray(0)
      assert(data.size == 2, s"Unable to load $modelClass data from: $datapath")
      data match {
        case Row(weights: Vector, intercept: Double) =>
          assert(weights.size == numFeatures, s"Expected $numFeatures features, but" +
            s" found ${weights.size} features when loading $modelClass weights from $datapath")
          Data(weights, intercept)
      }
    }
  }

} 
Example 65
Source File: ChiSqSelectorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}

class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
  with DefaultReadWriteTest {

  test("Test Chi-Square selector") {
    val sqlContext = SQLContext.getOrCreate(sc)
    import sqlContext.implicits._

    val data = Seq(
      LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
      LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
      LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
      LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))
    )

    val preFilteredData = Seq(
      Vectors.dense(0.0),
      Vectors.dense(6.0),
      Vectors.dense(8.0),
      Vectors.dense(5.0)
    )

    val df = sc.parallelize(data.zip(preFilteredData))
      .map(x => (x._1.label, x._1.features, x._2))
      .toDF("label", "data", "preFilteredData")

    val model = new ChiSqSelector()
      .setNumTopFeatures(1)
      .setFeaturesCol("data")
      .setLabelCol("label")
      .setOutputCol("filtered")

    model.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
      case Row(vec1: Vector, vec2: Vector) =>
        assert(vec1 ~== vec2 absTol 1e-1)
    }
  }

  test("ChiSqSelector read/write") {
    val t = new ChiSqSelector()
      .setFeaturesCol("myFeaturesCol")
      .setLabelCol("myLabelCol")
      .setOutputCol("myOutputCol")
      .setNumTopFeatures(2)
    testDefaultReadWrite(t)
  }

  test("ChiSqSelectorModel read/write") {
    val oldModel = new feature.ChiSqSelectorModel(Array(1, 3))
    val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel)
    val newInstance = testDefaultReadWrite(instance)
    assert(newInstance.selectedFeatures === instance.selectedFeatures)
  }
} 
Example 66
Source File: DCTSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import scala.beans.BeanInfo

import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}

@BeanInfo
case class DCTTestData(vec: Vector, wantedVec: Vector)

class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  test("forward transform of discrete cosine matches jTransforms result") {
    val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
    val inverse = false

    testDCT(data, inverse)
  }

  test("inverse transform of discrete cosine matches jTransforms result") {
    val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
    val inverse = true

    testDCT(data, inverse)
  }

  test("read/write") {
    val t = new DCT()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setInverse(true)
    testDefaultReadWrite(t)
  }

  private def testDCT(data: Vector, inverse: Boolean): Unit = {
    val expectedResultBuffer = data.toArray.clone()
    if (inverse) {
      (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true)
    } else {
      (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true)
    }
    val expectedResult = Vectors.dense(expectedResultBuffer)

    val dataset = sqlContext.createDataFrame(Seq(
      DCTTestData(data, expectedResult)
    ))

    val transformer = new DCT()
      .setInputCol("vec")
      .setOutputCol("resultVec")
      .setInverse(inverse)

    transformer.transform(dataset)
      .select("resultVec", "wantedVec")
      .collect()
      .foreach { case Row(resultVec: Vector, wantedVec: Vector) =>
      assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
    }
  }
} 
Example 67
Source File: StopWordsRemoverSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}

object StopWordsRemoverSuite extends SparkFunSuite {
  def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
    t.transform(dataset)
      .select("filtered", "expected")
      .collect()
      .foreach { case Row(tokens, wantedTokens) =>
        assert(tokens === wantedTokens)
    }
  }
}

class StopWordsRemoverSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  import StopWordsRemoverSuite._

  test("StopWordsRemover default") {
    val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
    val dataSet = sqlContext.createDataFrame(Seq(
      (Seq("test", "test"), Seq("test", "test")),
      (Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
      (Seq("a", "the", "an"), Seq()),
      (Seq("A", "The", "AN"), Seq()),
      (Seq(null), Seq(null)),
      (Seq(), Seq())
    )).toDF("raw", "expected")

    testStopWordsRemover(remover, dataSet)
  }

  test("StopWordsRemover case sensitive") {
    val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
      .setCaseSensitive(true)
    val dataSet = sqlContext.createDataFrame(Seq(
      (Seq("A"), Seq("A")),
      (Seq("The", "the"), Seq("The"))
    )).toDF("raw", "expected")

    testStopWordsRemover(remover, dataSet)
  }

  test("StopWordsRemover with additional words") {
    val stopWords = StopWords.English ++ Array("python", "scala")
    val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
      .setStopWords(stopWords)
    val dataSet = sqlContext.createDataFrame(Seq(
      (Seq("python", "scala", "a"), Seq()),
      (Seq("Python", "Scala", "swift"), Seq("swift"))
    )).toDF("raw", "expected")

    testStopWordsRemover(remover, dataSet)
  }

  test("read/write") {
    val t = new StopWordsRemover()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setStopWords(Array("the", "a"))
      .setCaseSensitive(true)
    testDefaultReadWrite(t)
  }

  test("StopWordsRemover output column already exists") {
    val outputCol = "expected"
    val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol(outputCol)
    val dataSet = sqlContext.createDataFrame(Seq(
      (Seq("The", "the", "swift"), Seq("swift"))
    )).toDF("raw", outputCol)

    val thrown = intercept[IllegalArgumentException] {
      testStopWordsRemover(remover, dataSet)
    }
    assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.")
  }
} 
Example 68
Source File: BinarizerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}

class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  @transient var data: Array[Double] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
  }

  test("params") {
    ParamsSuite.checkParams(new Binarizer)
  }

  test("Binarize continuous features with default parameter") {
    val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
    val dataFrame: DataFrame = sqlContext.createDataFrame(
      data.zip(defaultBinarized)).toDF("feature", "expected")

    val binarizer: Binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")

    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
      case Row(x: Double, y: Double) =>
        assert(x === y, "The feature value is not correct after binarization.")
    }
  }

  test("Binarize continuous features with setter") {
    val threshold: Double = 0.2
    val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
    val dataFrame: DataFrame = sqlContext.createDataFrame(
        data.zip(thresholdBinarized)).toDF("feature", "expected")

    val binarizer: Binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")
      .setThreshold(threshold)

    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
      case Row(x: Double, y: Double) =>
        assert(x === y, "The feature value is not correct after binarization.")
    }
  }

  test("read/write") {
    val t = new Binarizer()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setThreshold(0.1)
    testDefaultReadWrite(t)
  }
} 
Example 69
Source File: TokenizerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import scala.beans.BeanInfo

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}

@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])

class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  test("params") {
    ParamsSuite.checkParams(new Tokenizer)
  }

  test("read/write") {
    val t = new Tokenizer()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
    testDefaultReadWrite(t)
  }
}

class RegexTokenizerSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  import org.apache.spark.ml.feature.RegexTokenizerSuite._

  test("params") {
    ParamsSuite.checkParams(new RegexTokenizer)
  }

  test("RegexTokenizer") {
    val tokenizer0 = new RegexTokenizer()
      .setGaps(false)
      .setPattern("\\w+|\\p{Punct}")
      .setInputCol("rawText")
      .setOutputCol("tokens")
    val dataset0 = sqlContext.createDataFrame(Seq(
      TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")),
      TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct"))
    ))
    testRegexTokenizer(tokenizer0, dataset0)

    val dataset1 = sqlContext.createDataFrame(Seq(
      TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")),
      TokenizerTestData("Te,st. punct", Array("punct"))
    ))
    tokenizer0.setMinTokenLength(3)
    testRegexTokenizer(tokenizer0, dataset1)

    val tokenizer2 = new RegexTokenizer()
      .setInputCol("rawText")
      .setOutputCol("tokens")
    val dataset2 = sqlContext.createDataFrame(Seq(
      TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")),
      TokenizerTestData("Te,st.  punct", Array("te,st.", "punct"))
    ))
    testRegexTokenizer(tokenizer2, dataset2)
  }

  test("RegexTokenizer with toLowercase false") {
    val tokenizer = new RegexTokenizer()
      .setInputCol("rawText")
      .setOutputCol("tokens")
      .setToLowercase(false)
    val dataset = sqlContext.createDataFrame(Seq(
      TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")),
      TokenizerTestData("java scala", Array("java", "scala"))
    ))
    testRegexTokenizer(tokenizer, dataset)
  }

  test("read/write") {
    val t = new RegexTokenizer()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setMinTokenLength(2)
      .setGaps(false)
      .setPattern("hi")
      .setToLowercase(false)
    testDefaultReadWrite(t)
  }
}

object RegexTokenizerSuite extends SparkFunSuite {

  def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
    t.transform(dataset)
      .select("tokens", "wantedTokens")
      .collect()
      .foreach { case Row(tokens, wantedTokens) =>
        assert(tokens === wantedTokens)
      }
  }
} 
Example 70
Source File: MinMaxScalerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}

class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  test("MinMaxScaler fit basic case") {
    val sqlContext = new SQLContext(sc)

    val data = Array(
      Vectors.dense(1, 0, Long.MinValue),
      Vectors.dense(2, 0, 0),
      Vectors.sparse(3, Array(0, 2), Array(3, Long.MaxValue)),
      Vectors.sparse(3, Array(0), Array(1.5)))

    val expected: Array[Vector] = Array(
      Vectors.dense(-5, 0, -5),
      Vectors.dense(0, 0, 0),
      Vectors.sparse(3, Array(0, 2), Array(5, 5)),
      Vectors.sparse(3, Array(0), Array(-2.5)))

    val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
    val scaler = new MinMaxScaler()
      .setInputCol("features")
      .setOutputCol("scaled")
      .setMin(-5)
      .setMax(5)

    val model = scaler.fit(df)
    model.transform(df).select("expected", "scaled").collect()
      .foreach { case Row(vector1: Vector, vector2: Vector) =>
        assert(vector1.equals(vector2), "Transformed vector differs from the expected vector.")
    }

    // copied model must have the same parent.
    MLTestingUtils.checkCopy(model)
  }

  test("MinMaxScaler arguments max must be larger than min") {
    withClue("arguments max must be larger than min") {
      intercept[IllegalArgumentException] {
        val scaler = new MinMaxScaler().setMin(10).setMax(0)
        scaler.validateParams()
      }
      intercept[IllegalArgumentException] {
        val scaler = new MinMaxScaler().setMin(0).setMax(0)
        scaler.validateParams()
      }
    }
  }

  test("MinMaxScaler read/write") {
    val t = new MinMaxScaler()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setMax(1.0)
      .setMin(-1.0)
    testDefaultReadWrite(t)
  }

  test("MinMaxScalerModel read/write") {
    val instance = new MinMaxScalerModel(
        "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0))
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setMin(-1.0)
      .setMax(1.0)
    val newInstance = testDefaultReadWrite(instance)
    assert(newInstance.originalMin === instance.originalMin)
    assert(newInstance.originalMax === instance.originalMax)
  }
} 
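
For reference, the expected values in the first test can be reproduced by hand with the usual min-max formula, x' = (x - colMin) / (colMax - colMin) * (max - min) + min. A standalone sketch (not part of the suite) for the first feature column, whose observed min and max are 1.0 and 3.0 and whose target range is [-5, 5]:

object MinMaxByHand {
  def rescale(x: Double, colMin: Double, colMax: Double, min: Double, max: Double): Double =
    (x - colMin) / (colMax - colMin) * (max - min) + min

  def main(args: Array[String]): Unit = {
    val scaled = Seq(1.0, 2.0, 3.0, 1.5).map(rescale(_, 1.0, 3.0, -5.0, 5.0))
    println(scaled)  // List(-5.0, 0.0, 5.0, -2.5), matching the first components of the expected vectors
  }
}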
Example 71
Source File: PolynomialExpansionSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.ml.param.ParamsSuite
import org.scalatest.exceptions.TestFailedException

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row

class PolynomialExpansionSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  test("params") {
    ParamsSuite.checkParams(new PolynomialExpansion)
  }

  test("Polynomial expansion with default parameter") {
    val data = Array(
      Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
      Vectors.dense(-2.0, 2.3),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.6, -1.1, -3.0),
      Vectors.sparse(3, Seq())
    )

    val twoDegreeExpansion: Array[Vector] = Array(
      Vectors.sparse(9, Array(0, 1, 2, 3, 4), Array(-2.0, 4.0, 2.3, -4.6, 5.29)),
      Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29),
      Vectors.dense(new Array[Double](9)),
      Vectors.dense(0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0),
      Vectors.sparse(9, Array.empty, Array.empty))

    val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected")

    val polynomialExpansion = new PolynomialExpansion()
      .setInputCol("features")
      .setOutputCol("polyFeatures")

    polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach {
      case Row(expanded: DenseVector, expected: DenseVector) =>
        assert(expanded ~== expected absTol 1e-1)
      case Row(expanded: SparseVector, expected: SparseVector) =>
        assert(expanded ~== expected absTol 1e-1)
      case _ =>
        throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
    }
  }

  test("Polynomial expansion with setter") {
    val data = Array(
      Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
      Vectors.dense(-2.0, 2.3),
      Vectors.dense(0.0, 0.0, 0.0),
      Vectors.dense(0.6, -1.1, -3.0),
      Vectors.sparse(3, Seq())
    )

    val threeDegreeExpansion: Array[Vector] = Array(
      Vectors.sparse(19, Array(0, 1, 2, 3, 4, 5, 6, 7, 8),
        Array(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)),
      Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17),
      Vectors.dense(new Array[Double](19)),
      Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8,
        -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0),
      Vectors.sparse(19, Array.empty, Array.empty))

    val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected")

    val polynomialExpansion = new PolynomialExpansion()
      .setInputCol("features")
      .setOutputCol("polyFeatures")
      .setDegree(3)

    polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach {
      case Row(expanded: DenseVector, expected: DenseVector) =>
        assert(expanded ~== expected absTol 1e-1)
      case Row(expanded: SparseVector, expected: SparseVector) =>
        assert(expanded ~== expected absTol 1e-1)
      case _ =>
        throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
    }
  }

  test("read/write") {
    val t = new PolynomialExpansion()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setDegree(3)
    testDefaultReadWrite(t)
  }
} 
Example 72
Source File: IDFSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row

class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
    dataSet.map {
      case data: DenseVector =>
        val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y }
        Vectors.dense(res)
      case data: SparseVector =>
        val res = data.indices.zip(data.values).map { case (id, value) =>
          (id, value * model(id))
        }
        Vectors.sparse(data.size, res)
    }
  }

  test("params") {
    ParamsSuite.checkParams(new IDF)
    val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0)))
    ParamsSuite.checkParams(model)
  }

  test("compute IDF with default parameter") {
    val numOfFeatures = 4
    val data = Array(
      Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
      Vectors.dense(0.0, 1.0, 2.0, 3.0),
      Vectors.sparse(numOfFeatures, Array(1), Array(1.0))
    )
    val numOfData = data.size
    val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
      math.log((numOfData + 1.0) / (x + 1.0))
    })
    val expected = scaleDataWithIDF(data, idf)

    val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")

    val idfModel = new IDF()
      .setInputCol("features")
      .setOutputCol("idfValue")
      .fit(df)

    idfModel.transform(df).select("idfValue", "expected").collect().foreach {
      case Row(x: Vector, y: Vector) =>
        assert(x ~== y absTol 1e-5, "Transformed vector differs from the expected vector.")
    }
  }

  test("compute IDF with setter") {
    val numOfFeatures = 4
    val data = Array(
      Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
      Vectors.dense(0.0, 1.0, 2.0, 3.0),
      Vectors.sparse(numOfFeatures, Array(1), Array(1.0))
    )
    val numOfData = data.size
    val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
      if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0
    })
    val expected = scaleDataWithIDF(data, idf)

    val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")

    val idfModel = new IDF()
      .setInputCol("features")
      .setOutputCol("idfValue")
      .setMinDocFreq(1)
      .fit(df)

    idfModel.transform(df).select("idfValue", "expected").collect().foreach {
      case Row(x: Vector, y: Vector) =>
        assert(x ~== y absTol 1e-5, "Transformed vector differs from the expected vector.")
    }
  }

  test("IDF read/write") {
    val t = new IDF()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setMinDocFreq(5)
    testDefaultReadWrite(t)
  }

  test("IDFModel read/write") {
    val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0)))
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
    val newInstance = testDefaultReadWrite(instance)
    assert(newInstance.idf === instance.idf)
  }
} 
Example 73
Source File: NGramSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import scala.beans.BeanInfo

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}

@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])

class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
  import org.apache.spark.ml.feature.NGramSuite._

  test("default behavior yields bigram features") {
    val nGram = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
    val dataset = sqlContext.createDataFrame(Seq(
      NGramTestData(
        Array("Test", "for", "ngram", "."),
        Array("Test for", "for ngram", "ngram .")
    )))
    testNGram(nGram, dataset)
  }

  test("NGramLength=4 yields length 4 n-grams") {
    val nGram = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
      .setN(4)
    val dataset = sqlContext.createDataFrame(Seq(
      NGramTestData(
        Array("a", "b", "c", "d", "e"),
        Array("a b c d", "b c d e")
      )))
    testNGram(nGram, dataset)
  }

  test("empty input yields empty output") {
    val nGram = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
      .setN(4)
    val dataset = sqlContext.createDataFrame(Seq(
      NGramTestData(
        Array(),
        Array()
      )))
    testNGram(nGram, dataset)
  }

  test("input array < n yields empty output") {
    val nGram = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
      .setN(6)
    val dataset = sqlContext.createDataFrame(Seq(
      NGramTestData(
        Array("a", "b", "c", "d", "e"),
        Array()
      )))
    testNGram(nGram, dataset)
  }

  test("read/write") {
    val t = new NGram()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setN(3)
    testDefaultReadWrite(t)
  }
}

object NGramSuite extends SparkFunSuite {

  def testNGram(t: NGram, dataset: DataFrame): Unit = {
    t.transform(dataset)
      .select("nGrams", "wantedNGrams")
      .collect()
      .foreach { case Row(actualNGrams, wantedNGrams) =>
        assert(actualNGrams === wantedNGrams)
      }
  }
} 
Example 74
Source File: PCASuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
import org.apache.spark.sql.Row

class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  test("params") {
    ParamsSuite.checkParams(new PCA)
    val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
    val model = new PCAModel("pca", mat)
    ParamsSuite.checkParams(model)
  }

  test("pca") {
    val data = Array(
      Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
    )

    val dataRDD = sc.parallelize(data, 2)

    val mat = new RowMatrix(dataRDD)
    val pc = mat.computePrincipalComponents(3)
    val expected = mat.multiply(pc).rows

    val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")

    val pca = new PCA()
      .setInputCol("features")
      .setOutputCol("pca_features")
      .setK(3)
      .fit(df)

    // copied model must have the same parent.
    MLTestingUtils.checkCopy(pca)

    pca.transform(df).select("pca_features", "expected").collect().foreach {
      case Row(x: Vector, y: Vector) =>
        assert(x ~== y absTol 1e-5, "Transformed vector differs from the expected vector.")
    }
  }

  test("PCA read/write") {
    val t = new PCA()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setK(3)
    testDefaultReadWrite(t)
  }

  test("PCAModel read/write") {
    val instance = new PCAModel("myPCAModel",
      Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix])
    val newInstance = testDefaultReadWrite(instance)
    assert(newInstance.pc === instance.pc)
  }
} 
Example 75
Source File: QuantileDiscretizerSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkContext, SparkFunSuite}

class QuantileDiscretizerSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  import org.apache.spark.ml.feature.QuantileDiscretizerSuite._

  test("Test quantile discretizer") {
    checkDiscretizedData(sc,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      10,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))

    checkDiscretizedData(sc,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      4,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))

    checkDiscretizedData(sc,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      3,
      Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
      Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"))

    checkDiscretizedData(sc,
      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
      2,
      Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
      Array("-Infinity, 2.0", "2.0, Infinity"))

  }

  test("Test getting splits") {
    val splitTestPoints = Array(
      Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
      Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
      Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
      Array(Double.NegativeInfinity, Double.PositiveInfinity)
        -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
      Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
      Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity),
      Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity)
    )
    for ((ori, res) <- splitTestPoints) {
      assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
    }
  }

  test("read/write") {
    val t = new QuantileDiscretizer()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setNumBuckets(6)
    testDefaultReadWrite(t)
  }
}

private object QuantileDiscretizerSuite extends SparkFunSuite {

  def checkDiscretizedData(
      sc: SparkContext,
      data: Array[Double],
      numBucket: Int,
      expectedResult: Array[Double],
      expectedAttrs: Array[String]): Unit = {
    val sqlCtx = SQLContext.getOrCreate(sc)
    import sqlCtx.implicits._

    val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
    val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
      .setNumBuckets(numBucket)
    val result = discretizer.fit(df).transform(df)

    val transformedFeatures = result.select("result").collect()
      .map { case Row(transformedFeature: Double) => transformedFeature }
    val transformedAttrs = Attribute.fromStructField(result.schema("result"))
      .asInstanceOf[NominalAttribute].values.get

    assert(transformedFeatures === expectedResult,
      "Transformed features do not equal expected features.")
    assert(transformedAttrs === expectedAttrs,
      "Transformed attributes do not equal expected attributes.")
  }
} 
Example 76
Source File: MultilayerPerceptronClassifierSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.classification

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row

class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("XOR function learning as binary classification problem with two outputs.") {
    val dataFrame = sqlContext.createDataFrame(Seq(
        (Vectors.dense(0.0, 0.0), 0.0),
        (Vectors.dense(0.0, 1.0), 1.0),
        (Vectors.dense(1.0, 0.0), 1.0),
        (Vectors.dense(1.0, 1.0), 0.0))
    ).toDF("features", "label")
    val layers = Array[Int](2, 5, 2)
    val trainer = new MultilayerPerceptronClassifier()
      .setLayers(layers)
      .setBlockSize(1)
      .setSeed(11L)
      .setMaxIter(100)
    val model = trainer.fit(dataFrame)
    val result = model.transform(dataFrame)
    val predictionAndLabels = result.select("prediction", "label").collect()
    predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
      assert(p == l)
    }
  }

  // TODO: implement a more rigorous test
  test("3 class classification with 2 hidden layers") {
    val nPoints = 1000

    // The following coefficients are taken from OneVsRestSuite.scala;
    // they represent the 3-class iris dataset.
    val coefficients = Array(
      -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
      -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)

    val xMean = Array(5.843, 3.057, 3.758, 1.199)
    val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
    // the input seed is somewhat magic, to make this test pass
    val rdd = sc.parallelize(generateMultinomialLogisticInput(
      coefficients, xMean, xVariance, true, nPoints, 1), 2)
    val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features")
    val numClasses = 3
    val numIterations = 100
    val layers = Array[Int](4, 5, 4, numClasses)
    val trainer = new MultilayerPerceptronClassifier()
      .setLayers(layers)
      .setBlockSize(1)
      .setSeed(11L) // currently this seed is ignored
      .setMaxIter(numIterations)
    val model = trainer.fit(dataFrame)
    val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size
    assert(model.numFeatures === numFeatures)
    val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label")
      .map { case Row(p: Double, l: Double) => (p, l) }
    // train multinomial logistic regression
    val lr = new LogisticRegressionWithLBFGS()
      .setIntercept(true)
      .setNumClasses(numClasses)
    lr.optimizer.setRegParam(0.0)
      .setNumIterations(numIterations)
    val lrModel = lr.run(rdd)
    val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label))
    // MLP's predictions should not differ a lot from LR's.
    val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels)
    val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
    assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100)
  }
} 
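Once fitted, the model above can score new feature vectors directly. The sketch below reuses the suite's sqlContext and the XOR model from the first test; the input points are illustrative.

val newPoints = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 1.0)),
  Tuple1(Vectors.dense(1.0, 1.0)))).toDF("features")

// Each row gets a "prediction" column with the predicted class label.
model.transform(newPoints).select("features", "prediction").show()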
Example 77
Source File: DescribeHiveTableCommand.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import scala.collection.JavaConverters._

import org.apache.hadoop.hive.metastore.api.FieldSchema

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.MetastoreRelation
import org.apache.spark.sql.{Row, SQLContext}


private[hive]
case class DescribeHiveTableCommand(
    table: MetastoreRelation,
    override val output: Seq[Attribute],
    isExtended: Boolean) extends RunnableCommand {

  override def run(sqlContext: SQLContext): Seq[Row] = {
    // Trying to mimic the format of Hive's output. But not exactly the same.
    var results: Seq[(String, String, String)] = Nil

    val columns: Seq[FieldSchema] = table.hiveQlTable.getCols.asScala
    val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols.asScala
    results ++= columns.map(field => (field.getName, field.getType, field.getComment))
    if (partitionColumns.nonEmpty) {
      val partColumnInfo =
        partitionColumns.map(field => (field.getName, field.getType, field.getComment))
      results ++=
        partColumnInfo ++
          Seq(("# Partition Information", "", "")) ++
          Seq((s"# ${output(0).name}", output(1).name, output(2).name)) ++
          partColumnInfo
    }

    if (isExtended) {
      results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, ""))
    }

    results.map { case (name, dataType, comment) =>
      Row(name, dataType, comment)
    }
  }
} 
Example 78
Source File: CreateViewAsSelect.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveContext}
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable}


// TODO: Note that this class can NOT canonicalize the view SQL string entirely, which is different
// from Hive and may not work for some cases like create view on self join.
private[hive] case class CreateViewAsSelect(
    tableDesc: HiveTable,
    childSchema: Seq[Attribute],
    allowExisting: Boolean,
    orReplace: Boolean) extends RunnableCommand {

  assert(tableDesc.schema == Nil || tableDesc.schema.length == childSchema.length)
  assert(tableDesc.viewText.isDefined)

  val tableIdentifier = TableIdentifier(tableDesc.name, Some(tableDesc.database))

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val hiveContext = sqlContext.asInstanceOf[HiveContext]

    if (hiveContext.catalog.tableExists(tableIdentifier)) {
      if (allowExisting) {
        // view already exists, will do nothing, to keep consistent with Hive
      } else if (orReplace) {
        hiveContext.catalog.client.alertView(prepareTable())
      } else {
        throw new AnalysisException(s"View $tableIdentifier already exists. " +
          "If you want to update the view definition, please use ALTER VIEW AS or " +
          "CREATE OR REPLACE VIEW AS")
      }
    } else {
      hiveContext.catalog.client.createView(prepareTable())
    }

    Seq.empty[Row]
  }

  private def prepareTable(): HiveTable = {
    // setup column types according to the schema of child.
    val schema = if (tableDesc.schema == Nil) {
      childSchema.map { attr =>
        HiveColumn(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null)
      }
    } else {
      childSchema.zip(tableDesc.schema).map { case (attr, col) =>
        HiveColumn(col.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), col.comment)
      }
    }

    val columnNames = childSchema.map(f => verbose(f.name))

    // When user specified column names for view, we should create a project to do the renaming.
    // When no column name specified, we still need to create a project to declare the columns
    // we need, to make us more robust to top level `*`s.
    val projectList = if (tableDesc.schema == Nil) {
      columnNames.mkString(", ")
    } else {
      columnNames.zip(tableDesc.schema.map(f => verbose(f.name))).map {
        case (name, alias) => s"$name AS $alias"
      }.mkString(", ")
    }

    val viewName = verbose(tableDesc.name)

    val expandedText = s"SELECT $projectList FROM (${tableDesc.viewText.get}) $viewName"

    tableDesc.copy(schema = schema, viewText = Some(expandedText))
  }

  // escape backtick with double-backtick in column name and wrap it with backtick.
  private def verbose(name: String) = s"`${name.replaceAll("`", "``")}`"
} 
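As a quick illustration of the quoting rule implemented by verbose above, here is a standalone copy (with a hypothetical name) that can be checked in a Scala REPL.

// Standalone copy of the escaping logic above, for illustration only.
def quoteIdentifier(name: String): String = s"`${name.replaceAll("`", "``")}`"

require(quoteIdentifier("plain") == "`plain`")
require(quoteIdentifier("with`tick") == "`with``tick`")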
Example 79
Source File: CreateTableAsSelect.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable}
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation}
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}


private[hive]
case class CreateTableAsSelect(
    tableDesc: HiveTable,
    query: LogicalPlan,
    allowExisting: Boolean)
  extends RunnableCommand {

  val tableIdentifier = TableIdentifier(tableDesc.name, Some(tableDesc.database))

  override def children: Seq[LogicalPlan] = Seq(query)

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val hiveContext = sqlContext.asInstanceOf[HiveContext]
    lazy val metastoreRelation: MetastoreRelation = {
      import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
      import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
      import org.apache.hadoop.io.Text
      import org.apache.hadoop.mapred.TextInputFormat

      val withFormat =
        tableDesc.copy(
          inputFormat =
            tableDesc.inputFormat.orElse(Some(classOf[TextInputFormat].getName)),
          outputFormat =
            tableDesc.outputFormat
              .orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)),
          serde = tableDesc.serde.orElse(Some(classOf[LazySimpleSerDe].getName())))

      val withSchema = if (withFormat.schema.isEmpty) {
        // Hive doesn't support specifying a column list for the target table in CTAS,
        // but we don't think Spark SQL should follow that restriction.
        tableDesc.copy(schema =
        query.output.map(c =>
          HiveColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType), null)))
      } else {
        withFormat
      }

      hiveContext.catalog.client.createTable(withSchema)

      // Get the Metastore Relation
      hiveContext.catalog.lookupRelation(tableIdentifier, None) match {
        case r: MetastoreRelation => r
      }
    }
    // TODO: ideally, we should get the output data ready first and then
    // add the relation into the catalog, in case a failure occurs while
    // processing the data.
    if (hiveContext.catalog.tableExists(tableIdentifier)) {
      if (allowExisting) {
        // table already exists, will do nothing, to keep consistent with Hive
      } else {
        throw new AnalysisException(s"$tableIdentifier already exists.")
      }
    } else {
      hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd
    }

    Seq.empty[Row]
  }

  override def argString: String = {
    s"[Database:${tableDesc.database}}, TableName: ${tableDesc.name}, InsertIntoHiveTable]"
  }
} 
Example 80
Source File: HiveDataFrameAnalyticsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.scalatest.BeforeAndAfterAll

// TODO: ideally we should put this test suite into the `sql` package, as the
// `hive` package is optional at compile time; however, `SQLContext.sql` doesn't
// support `cube` or `rollup` yet.
class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
  import hiveContext.implicits._
  import hiveContext.sql

  private var testData: DataFrame = _

  override def beforeAll() {
    testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
    hiveContext.registerDataFrameAsTable(testData, "mytable")
  }

  override def afterAll(): Unit = {
    hiveContext.dropTempTable("mytable")
  }

  test("rollup") {
    checkAnswer(
      testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
      sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect()
    )

    checkAnswer(
      testData.rollup("a", "b").agg(sum("b")),
      sql("select a, b, sum(b) from mytable group by a, b with rollup").collect()
    )
  }

  test("collect functions") {
    checkAnswer(
      testData.select(collect_list($"a"), collect_list($"b")),
      Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
    )
    checkAnswer(
      testData.select(collect_set($"a"), collect_set($"b")),
      Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
    )
  }

  test("cube") {
    checkAnswer(
      testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
      sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect()
    )

    checkAnswer(
      testData.cube("a", "b").agg(sum("b")),
      sql("select a, b, sum(b) from mytable group by a, b with cube").collect()
    )
  }
} 
Example 81
Source File: OrcHadoopFsRelationSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.orc

import org.apache.hadoop.fs.Path

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.{Row, SQLConf}
import org.apache.spark.sql.sources.HadoopFsRelationTest
import org.apache.spark.sql.types._

class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
  import testImplicits._

  override val dataSourceName: String = classOf[DefaultSource].getCanonicalName

  // ORC does not play well with NullType and UDT.
  override protected def supportsDataType(dataType: DataType): Boolean = dataType match {
    case _: NullType => false
    case _: CalendarIntervalType => false
    case _: UserDefinedType[_] => false
    case _ => true
  }

  test("save()/load() - partitioned table - simple queries - partition columns in data") {
    withTempDir { file =>
      val basePath = new Path(file.getCanonicalPath)
      val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
      val qualifiedBasePath = fs.makeQualified(basePath)

      for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
        val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
        sparkContext
          .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
          .toDF("a", "b", "p1")
          .write
          .orc(partitionDir.toString)
      }

      val dataSchemaWithPartition =
        StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))

      checkQueries(
        hiveContext.read.options(Map(
          "path" -> file.getCanonicalPath,
          "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load())
    }
  }

  test("SPARK-12218: 'Not' is included in ORC filter pushdown") {
    import testImplicits._

    withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
      withTempPath { dir =>
        val path = s"${dir.getCanonicalPath}/table1"
        (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.orc(path)

        checkAnswer(
          sqlContext.read.orc(path).where("not (a = 2) or not(b in ('1'))"),
          (1 to 5).map(i => Row(i, (i % 2).toString)))

        checkAnswer(
          sqlContext.read.orc(path).where("not (a = 2 and b in ('1'))"),
          (1 to 5).map(i => Row(i, (i % 2).toString)))
      }
    }
  }
} 
Example 82
Source File: HiveParquetSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.{QueryTest, Row}

case class Cases(lower: String, UPPER: String)

class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton {

  test("Case insensitive attribute names") {
    withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") {
      val expected = (1 to 4).map(i => Row(i.toString))
      checkAnswer(sql("SELECT upper FROM cases"), expected)
      checkAnswer(sql("SELECT LOWER FROM cases"), expected)
    }
  }

  test("SELECT on Parquet table") {
    val data = (1 to 4).map(i => (i, s"val_$i"))
    withParquetTable(data, "t") {
      checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple))
    }
  }

  test("Simple column projection + filter on Parquet table") {
    withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") {
      checkAnswer(
        sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"),
        Seq(Row(true, "val_2"), Row(true, "val_4")))
    }
  }

  test("Converting Hive to Parquet Table via saveAsParquetFile") {
    withTempPath { dir =>
      sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath)
      hiveContext.read.parquet(dir.getCanonicalPath).registerTempTable("p")
      withTempTable("p") {
        checkAnswer(
          sql("SELECT * FROM src ORDER BY key"),
          sql("SELECT * from p ORDER BY key").collect().toSeq)
      }
    }
  }

  test("INSERT OVERWRITE TABLE Parquet table") {
    withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") {
      withTempPath { file =>
        sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
        hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p")
        withTempTable("p") {
          // let's do three overwrites for good measure
          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
          checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq)
        }
      }
    }
  }
} 
Example 83
Source File: HiveDataFrameJoinSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.hive.test.TestHiveSingleton

class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton {
  import hiveContext.implicits._

  // We should move this into SQL package if we make case sensitivity configurable in SQL.
  test("join - self join auto resolve ambiguity with case insensitivity") {
    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
    checkAnswer(
      df.join(df, df("key") === df("Key")),
      Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil)

    checkAnswer(
      df.join(df.filter($"value" === "2"), df("key") === df("Key")),
      Row(2, "2", 2, "2") :: Nil)
  }

} 
Example 84
Source File: HiveTableScanSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._

import org.apache.spark.util.Utils

class HiveTableScanSuite extends HiveComparisonTest {

  createQueryTest("partition_based_table_scan_with_different_serde",
    """
      |CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (ds STRING)
      |ROW FORMAT SERDE
      |'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe'
      |STORED AS RCFILE;
      |
      |FROM src
      |INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-01')
      |SELECT 100,100 LIMIT 1;
      |
      |ALTER TABLE part_scan_test SET SERDE
      |'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe';
      |
      |FROM src INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-02')
      |SELECT 200,200 LIMIT 1;
      |
      |SELECT * from part_scan_test;
    """.stripMargin)

  // In unit tests, kv1.txt is a small file and will be loaded as table src.
  // Since the small file will be treated as a single split, we assume
  // Hive and Spark SQL HQL produce the same output even for SORT BY.
  createQueryTest("file_split_for_small_table",
    """
      |SELECT key, value FROM src SORT BY key, value
    """.stripMargin)

  test("Spark-4041: lowercase issue") {
    TestHive.sql("CREATE TABLE tb (KEY INT, VALUE STRING) STORED AS ORC")
    TestHive.sql("insert into table tb select key, value from src")
    TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect()
    TestHive.sql("drop table tb")
  }

  test("Spark-4077: timestamp query for null value") {
    TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null")
    TestHive.sql(
      """
        CREATE EXTERNAL TABLE timestamp_query_null (time TIMESTAMP,id INT)
        ROW FORMAT DELIMITED
        FIELDS TERMINATED BY ','
        LINES TERMINATED BY '\n'
      """.stripMargin)
    val location =
      Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile()

    TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null")
    assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect()
      === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null)))
    TestHive.sql("DROP TABLE timestamp_query_null")
  }

  test("Spark-4959 Attributes are case sensitive when using a select query from a projection") {
    sql("create table spark_4959 (col1 string)")
    sql("""insert into table spark_4959 select "hi" from src limit 1""")
    table("spark_4959").select(
      'col1.as("CaseSensitiveColName"),
      'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2")

    assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi"))
    assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi"))
  }
} 
Example 85
Source File: HiveOperatorQueryableSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton}


class HiveOperatorQueryableSuite extends QueryTest with TestHiveSingleton {
  import hiveContext._

  test("SPARK-5324 query result of describe command") {
    hiveContext.loadTestTable("src")

    // register a describe command to be a temp table
    sql("desc src").registerTempTable("mydesc")
    checkAnswer(
      sql("desc mydesc"),
      Seq(
        Row("col_name", "string", "name of the column"),
        Row("data_type", "string", "data type of the column"),
        Row("comment", "string", "comment of the column")))

    checkAnswer(
      sql("select * from mydesc"),
      Seq(
        Row("key", "int", null),
        Row("value", "string", null)))

    checkAnswer(
      sql("select col_name, data_type, comment from mydesc"),
      Seq(
        Row("key", "int", null),
        Row("value", "string", null)))
  }
} 
Example 86
Source File: ListTablesSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.Row

class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
  import hiveContext._
  import hiveContext.implicits._

  val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")

  override def beforeAll(): Unit = {
    // The catalog in HiveContext is a case insensitive one.
    catalog.registerTable(TableIdentifier("ListTablesSuiteTable"), df.logicalPlan)
    sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)")
    sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB")
    sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)")
  }

  override def afterAll(): Unit = {
    catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
    sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable")
    sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable")
    sql("DROP DATABASE IF EXISTS ListTablesSuiteDB")
  }

  test("get all tables of current database") {
    Seq(tables(), sql("SHOW TABLes")).foreach {
      case allTables =>
        // We are using default DB.
        checkAnswer(
          allTables.filter("tableName = 'listtablessuitetable'"),
          Row("listtablessuitetable", true))
        checkAnswer(
          allTables.filter("tableName = 'hivelisttablessuitetable'"),
          Row("hivelisttablessuitetable", false))
        assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0)
    }
  }

  test("getting all tables with a database name") {
    Seq(tables("listtablessuiteDb"), sql("SHOW TABLes in listTablesSuitedb")).foreach {
      case allTables =>
        checkAnswer(
          allTables.filter("tableName = 'listtablessuitetable'"),
          Row("listtablessuitetable", true))
        assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0)
        checkAnswer(
          allTables.filter("tableName = 'hiveindblisttablessuitetable'"),
          Row("hiveindblisttablessuitetable", false))
    }
  }
} 
Example 87
Source File: JsonHadoopFsRelationSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import java.math.BigDecimal

import org.apache.hadoop.fs.Path

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class JsonHadoopFsRelationSuite extends HadoopFsRelationTest {
  override val dataSourceName: String = "json"

  // JSON does not write data of NullType and does not play well with BinaryType.
  override protected def supportsDataType(dataType: DataType): Boolean = dataType match {
    case _: NullType => false
    case _: BinaryType => false
    case _: CalendarIntervalType => false
    case _ => true
  }

  test("save()/load() - partitioned table - simple queries - partition columns in data") {
    withTempDir { file =>
      val basePath = new Path(file.getCanonicalPath)
      val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
      val qualifiedBasePath = fs.makeQualified(basePath)

      for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
        val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
        sparkContext
          .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""")
          .saveAsTextFile(partitionDir.toString)
      }

      val dataSchemaWithPartition =
        StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))

      checkQueries(
        hiveContext.read.format(dataSourceName)
          .option("dataSchema", dataSchemaWithPartition.json)
          .load(file.getCanonicalPath))
    }
  }

  test("SPARK-9894: save complex types to JSON") {
    withTempDir { file =>
      file.delete()

      val schema =
        new StructType()
          .add("array", ArrayType(LongType))
          .add("map", MapType(StringType, new StructType().add("innerField", LongType)))

      val data =
        Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) ::
          Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil
      val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema)

      // Write the data out.
      df.write.format(dataSourceName).save(file.getCanonicalPath)

      // Read it back and check the result.
      checkAnswer(
        hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath),
        df
      )
    }
  }

  test("SPARK-10196: save decimal type to JSON") {
    withTempDir { file =>
      file.delete()

      val schema =
        new StructType()
          .add("decimal", DecimalType(7, 2))

      val data =
        Row(new BigDecimal("10.02")) ::
          Row(new BigDecimal("20000.99")) ::
          Row(new BigDecimal("10000")) :: Nil
      val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema)

      // Write the data out.
      df.write.format(dataSourceName).save(file.getCanonicalPath)

      // Read it back and check the result.
      checkAnswer(
        hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath),
        df
      )
    }
  }
} 
Example 88
Source File: LocalRelation.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis}
import org.apache.spark.sql.types.{StructField, StructType}

object LocalRelation {
  def apply(output: Attribute*): LocalRelation = new LocalRelation(output)

  def apply(output1: StructField, output: StructField*): LocalRelation = {
    new LocalRelation(StructType(output1 +: output).toAttributes)
  }

  def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = {
    val schema = StructType.fromAttributes(output)
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
    LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow]))
  }

  def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = {
    val schema = StructType.fromAttributes(output)
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
    LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow]))
  }
}

case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
  extends LeafNode with analysis.MultiInstanceRelation {

  
  override final def newInstance(): this.type = {
    LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type]
  }

  override protected def stringArgs = Iterator(output)

  override def sameResult(plan: LogicalPlan): Boolean = plan match {
    case LocalRelation(otherOutput, otherData) =>
      otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data
    case _ => false
  }

  override lazy val statistics =
    Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length)
} 
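A minimal sketch of building a relation with the factory methods above; the schema and rows are illustrative, and toAttributes is the same StructType helper used elsewhere in these examples.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)))

// Wrap two external Rows as a LocalRelation of InternalRows.
val relation = LocalRelation.fromExternalRows(
  schema.toAttributes,
  Seq(Row(1, "a"), Row(2, "b")))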
Example 89
Source File: CatalystTypeConvertersSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class CatalystTypeConvertersSuite extends SparkFunSuite {

  private val simpleTypes: Seq[DataType] = Seq(
    StringType,
    DateType,
    BooleanType,
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType.SYSTEM_DEFAULT,
    DecimalType.USER_DEFAULT)

  test("null handling in rows") {
    val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t)))
    val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
    val convertToScala = CatalystTypeConverters.createToScalaConverter(schema)

    val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null))
    assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow)
  }

  test("null handling for individual values") {
    for (dataType <- simpleTypes) {
      assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null)
    }
  }

  test("option handling in convertToCatalyst") {
    // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with
    // createToCatalystConverter but it may not actually matter as this is only called internally
    // in a handful of places where we don't expect to receive Options.
    assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123))
  }

  test("option handling in createToCatalystConverter") {
    assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
  }
} 
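The same converters can be exercised outside a test suite. A small sketch, with an illustrative schema, round-trips a Row through the Catalyst representation and back.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("word", StringType),
  StructField("count", IntegerType)))
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
val toScala = CatalystTypeConverters.createToScalaConverter(schema)

// Scala Row -> Catalyst InternalRow -> Scala Row should be lossless here.
val original = Row("spark", 42)
require(toScala(toCatalyst(original)) == original)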
Example 90
Source File: JDBCRelation.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import java.util.Properties

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}


private[sql] object JDBCRelation {
  def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
    if (partitioning == null) return Array[Partition](JDBCPartition(null, 0))

    val numPartitions = partitioning.numPartitions
    val column = partitioning.column
    if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0))
    // Overflow and silliness can happen if you subtract then divide.
    // Here we get a little roundoff, but that's (hopefully) OK.
    val stride: Long = (partitioning.upperBound / numPartitions
                      - partitioning.lowerBound / numPartitions)
    var i: Int = 0
    var currentValue: Long = partitioning.lowerBound
    var ans = new ArrayBuffer[Partition]()
    while (i < numPartitions) {
      val lowerBound = if (i != 0) s"$column >= $currentValue" else null
      currentValue += stride
      val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
      val whereClause =
        if (upperBound == null) {
          lowerBound
        } else if (lowerBound == null) {
          upperBound
        } else {
          s"$lowerBound AND $upperBound"
        }
      ans += JDBCPartition(whereClause, i)
      i = i + 1
    }
    ans.toArray
  }
}

private[sql] case class JDBCRelation(
    url: String,
    table: String,
    parts: Array[Partition],
    properties: Properties = new Properties())(@transient val sqlContext: SQLContext)
  extends BaseRelation
  with PrunedFilteredScan
  with InsertableRelation {

  override val needConversion: Boolean = false

  override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    JDBCRDD.scanTable(
      sqlContext.sparkContext,
      schema,
      url,
      properties,
      table,
      requiredColumns,
      filters,
      parts).asInstanceOf[RDD[Row]]
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    data.write
      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
      .jdbc(url, table, properties)
  }
} 
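To see what the stride arithmetic in columnPartition produces, here is a hedged, standalone rework of that loop; the object name is illustrative and it is not part of the Spark API.

object PartitionClausePreview {
  // Reproduces the WHERE-clause generation above for a quick sanity check.
  def clauses(column: String, lower: Long, upper: Long, numPartitions: Int): Seq[String] = {
    val stride = upper / numPartitions - lower / numPartitions
    var current = lower
    (0 until numPartitions).map { i =>
      val lowerClause = if (i != 0) s"$column >= $current" else null
      current += stride
      val upperClause = if (i != numPartitions - 1) s"$column < $current" else null
      if (upperClause == null) lowerClause
      else if (lowerClause == null) upperClause
      else s"$lowerClause AND $upperClause"
    }
  }
}

// Splitting id in [0, 100) into 4 partitions yields:
//   id < 25
//   id >= 25 AND id < 50
//   id >= 50 AND id < 75
//   id >= 75
PartitionClausePreview.clauses("id", 0L, 100L, 4).foreach(println)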
Example 91
Source File: JacksonGenerator.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{MapData, ArrayData, DateTimeUtils}

import scala.collection.Map

import com.fasterxml.jackson.core._

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

private[sql] object JacksonGenerator {
  
  def apply(rowSchema: StructType, gen: JsonGenerator)(row: InternalRow): Unit = {
    def valWriter: (DataType, Any) => Unit = {
      case (_, null) | (NullType, _) => gen.writeNull()
      case (StringType, v) => gen.writeString(v.toString)
      case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString)
      case (IntegerType, v: Int) => gen.writeNumber(v)
      case (ShortType, v: Short) => gen.writeNumber(v)
      case (FloatType, v: Float) => gen.writeNumber(v)
      case (DoubleType, v: Double) => gen.writeNumber(v)
      case (LongType, v: Long) => gen.writeNumber(v)
      case (DecimalType(), v: Decimal) => gen.writeNumber(v.toJavaBigDecimal)
      case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
      case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
      case (BooleanType, v: Boolean) => gen.writeBoolean(v)
      case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString)
      // For UDT values, they should be in the SQL type's corresponding value type.
      // We should not see values of the user-defined class here.
      // For example, VectorUDT's SQL type is an array of double, so we expect v to be
      // an ArrayData here, instead of a Vector.
      case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v)

      case (ArrayType(ty, _), v: ArrayData) =>
        gen.writeStartArray()
        v.foreach(ty, (_, value) => valWriter(ty, value))
        gen.writeEndArray()

      case (MapType(kt, vt, _), v: MapData) =>
        gen.writeStartObject()
        v.foreach(kt, vt, { (k, v) =>
          gen.writeFieldName(k.toString)
          valWriter(vt, v)
        })
        gen.writeEndObject()

      case (StructType(ty), v: InternalRow) =>
        gen.writeStartObject()
        var i = 0
        while (i < ty.length) {
          val field = ty(i)
          val value = v.get(i, field.dataType)
          if (value != null) {
            gen.writeFieldName(field.name)
            valWriter(field.dataType, value)
          }
          i += 1
        }
        gen.writeEndObject()

      case (dt, v) =>
        sys.error(
          s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.")
    }

    valWriter(rowSchema, row)
  }
} 
Example 92
Source File: FrequentItems.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, Column, DataFrame}

private[sql] object FrequentItems extends Logging {

  
  private[sql] def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    require(support >= 1e-4, s"support ($support) must be greater than 1e-4.")
    val numCols = cols.length
    // maximum number of items to keep counts for
    val sizeOfMap = (1 / support).toInt
    val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
    val originalSchema = df.schema
    val colInfo: Array[(String, DataType)] = cols.map { name =>
      val index = originalSchema.fieldIndex(name)
      (name, originalSchema.fields(index).dataType)
    }.toArray

    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
      seqOp = (counts, row) => {
        var i = 0
        while (i < numCols) {
          val thisMap = counts(i)
          val key = row.get(i)
          thisMap.add(key, 1L)
          i += 1
        }
        counts
      },
      combOp = (baseCounts, counts) => {
        var i = 0
        while (i < numCols) {
          baseCounts(i).merge(counts(i))
          i += 1
        }
        baseCounts
      }
    )
    val justItems = freqItems.map(m => m.baseMap.keys.toArray)
    val resultRow = Row(justItems : _*)
    // append "_freqItems" to each column name for easy debugging
    val outputCols = colInfo.map { v =>
      StructField(v._1 + "_freqItems", ArrayType(v._2, false))
    }
    val schema = StructType(outputCols).toAttributes
    new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
  }
} 
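The DataFrame-level entry point for this logic is df.stat.freqItems. A minimal sketch, assuming an existing SQLContext named sqlCtx:

import sqlCtx.implicits._

val df = sqlCtx.sparkContext.parallelize(Seq(
  (1, "a"), (1, "a"), (2, "b"), (1, "c"), (1, "a"), (3, "a"))).toDF("num", "letter")

// Items occurring in at least 40% of rows, computed per column.
val freq = df.stat.freqItems(Seq("num", "letter"), 0.4)
freq.show()   // columns: num_freqItems, letter_freqItems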
Example 93
Source File: ExistingRDD.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SQLContext}


object RDDConversions {
  def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
    data.mapPartitions { iterator =>
      val numColumns = outputTypes.length
      val mutableRow = new GenericMutableRow(numColumns)
      val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter)
      iterator.map { r =>
        var i = 0
        while (i < numColumns) {
          mutableRow(i) = converters(i)(r.productElement(i))
          i += 1
        }

        mutableRow
      }
    }
  }
}

//private[sql]
case class PhysicalRDD(
    output: Seq[Attribute],
    rdd: RDD[InternalRow],
    override val nodeName: String,
    override val metadata: Map[String, String] = Map.empty,
    override val outputsUnsafeRows: Boolean = false)
  extends LeafNode {

  protected override def doExecute(): RDD[InternalRow] = rdd

  override def simpleString: String = {
    val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value"
    s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}"
  }
}

private[sql] object PhysicalRDD {
  // Metadata keys
  val INPUT_PATHS = "InputPaths"
  val PUSHED_FILTERS = "PushedFilters"

  def createFromDataSource(
      output: Seq[Attribute],
      rdd: RDD[InternalRow],
      relation: BaseRelation,
      metadata: Map[String, String] = Map.empty): PhysicalRDD = {
    // All HadoopFsRelations output UnsafeRows
    val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation]
    PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows)
  }
} 
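A small sketch of feeding an RDD of case classes through the converter above; the case class and the SparkContext sc are assumed.

import org.apache.spark.sql.types.DoubleType

case class Point(x: Double, y: Double)

val points = sc.parallelize(Seq(Point(1.0, 2.0), Point(3.0, 4.0)))
// Each Point becomes a Catalyst row with two double columns; note that the
// underlying mutable row is reused across the iterator, as in the code above.
val internalRows = RDDConversions.productToRowRdd(points, Seq(DoubleType, DoubleType))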
Example 94
Source File: SemiJoinSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  private lazy val left = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  private lazy val right = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join)
    }

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            EnsureRequirements(left.sqlContext).apply(
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }
  }

  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
} 
Example 95
Source File: TextSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.text

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.util.Utils


class TextSuite extends QueryTest with SharedSQLContext {

  test("reading text file") {
    verifyFrame(sqlContext.read.format("text").load(testFile))
  }

  test("SQLContext.read.text() API") {
    verifyFrame(sqlContext.read.text(testFile))
  }

  test("SPARK-12562 verify write.text() can handle column name beyond `value`") {
    val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf")

    val tempFile = Utils.createTempDir()
    tempFile.delete()
    df.write.text(tempFile.getCanonicalPath)
    verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))

    Utils.deleteRecursively(tempFile)
  }

  test("error handling for invalid schema") {
    val tempFile = Utils.createTempDir()
    tempFile.delete()

    val df = sqlContext.range(2)
    intercept[AnalysisException] {
      df.write.text(tempFile.getCanonicalPath)
    }

    intercept[AnalysisException] {
      sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath)
    }
  }

  private def testFile: String = {
    Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
  }

  
  private def verifyFrame(df: DataFrame): Unit = {
    // schema
    assert(df.schema == new StructType().add("value", StringType))

    // verify content
    val data = df.collect()
    assert(data(0) == Row("This is a test file for the text data source"))
    assert(data(1) == Row("1+1"))
    // non-ASCII characters are not allowed in the code, so we disable scalastyle here.
    // scalastyle:off
    assert(data(2) == Row("数据砖头"))
    // scalastyle:on
    assert(data(3) == Row("\"doh\""))
    assert(data.length == 4)
  }
} 
Example 96
Source File: ParquetInteroperabilitySuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import java.io.File

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("parquet files with different physical schemas but share the same logical schema") {
    import ParquetCompatibilityTest._

    // This test case writes two Parquet files, both representing the following Catalyst schema
    //
    //   StructType(
    //     StructField(
    //       "f",
    //       ArrayType(IntegerType, containsNull = false),
    //       nullable = false))
    //
    // The first Parquet file comes with parquet-avro style 2-level LIST-annotated group, while the
    // other one comes with parquet-protobuf style 1-level unannotated primitive field.
    withTempDir { dir =>
      val avroStylePath = new File(dir, "avro-style").getCanonicalPath
      val protobufStylePath = new File(dir, "protobuf-style").getCanonicalPath

      val avroStyleSchema =
        """message avro_style {
          |  required group f (LIST) {
          |    repeated int32 array;
          |  }
          |}
        """.stripMargin

      writeDirect(avroStylePath, avroStyleSchema, { rc =>
        rc.message {
          rc.field("f", 0) {
            rc.group {
              rc.field("array", 0) {
                rc.addInteger(0)
                rc.addInteger(1)
              }
            }
          }
        }
      })

      logParquetSchema(avroStylePath)

      val protobufStyleSchema =
        """message protobuf_style {
          |  repeated int32 f;
          |}
        """.stripMargin

      writeDirect(protobufStylePath, protobufStyleSchema, { rc =>
        rc.message {
          rc.field("f", 0) {
            rc.addInteger(2)
            rc.addInteger(3)
          }
        }
      })

      logParquetSchema(protobufStylePath)

      checkAnswer(
        sqlContext.read.parquet(dir.getCanonicalPath),
        Seq(
          Row(Seq(0, 1)),
          Row(Seq(2, 3))))
    }
  }
} 
Example 97
Source File: ParquetProtobufCompatibilitySuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("unannotated array of primitive type") {
    checkAnswer(readResourceParquetFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
  }

  test("unannotated array of struct") {
    checkAnswer(
      readResourceParquetFile("old-repeated-message.parquet"),
      Row(
        Seq(
          Row("First inner", null, null),
          Row(null, "Second inner", null),
          Row(null, null, "Third inner"))))

    checkAnswer(
      readResourceParquetFile("proto-repeated-struct.parquet"),
      Row(
        Seq(
          Row("0 - 1", "0 - 2", "0 - 3"),
          Row("1 - 1", "1 - 2", "1 - 3"))))

    checkAnswer(
      readResourceParquetFile("proto-struct-with-array-many.parquet"),
      Seq(
        Row(
          Seq(
            Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"),
            Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))),
        Row(
          Seq(
            Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"),
            Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))),
        Row(
          Seq(
            Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"),
            Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3")))))
  }

  test("struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("proto-struct-with-array.parquet"),
      Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
  }

  test("unannotated array of struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("nested-array-struct.parquet"),
      Seq(
        Row(2, Seq(Row(1, Seq(Row(3))))),
        Row(5, Seq(Row(4, Seq(Row(6))))),
        Row(8, Seq(Row(7, Seq(Row(9)))))))
  }

  test("unannotated array of string") {
    checkAnswer(
      readResourceParquetFile("proto-repeated-string.parquet"),
      Seq(
        Row(Seq("hello", "world")),
        Row(Seq("good", "bye")),
        Row(Seq("one", "two", "three"))))
  }
} 
Example 98
Source File: ExchangeSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.test.SharedSQLContext

class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  test("shuffling UnsafeRows in exchange") {
    val input = (1 to 1000).map(Tuple1.apply)
    checkAnswer(
      input.toDF(),
      plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))),
      input.map(Row.fromTuple)
    )
  }
} 
Example 99
Source File: ExpandSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.IntegerType

class ExpandSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  private def testExpand(f: SparkPlan => SparkPlan): Unit = {
    val input = (1 to 1000).map(Tuple1.apply)
    val projections = Seq.tabulate(2) { i =>
      Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil
    }
    val attributes = projections.head.map(_.toAttribute)
    checkAnswer(
      input.toDF(),
      plan => Expand(projections, attributes, f(plan)),
      input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j)))
    )
  }

  test("inheriting child row type") {
    val exprs = AttributeReference("a", IntegerType, false)() :: Nil
    val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty)))
    assert(plan.outputsUnsafeRows, "Expand should inherit the created row type from its child.")
  }

  test("expanding UnsafeRows") {
    testExpand(ConvertToUnsafe)
  }

  test("expanding SafeRows") {
    testExpand(identity)
  }
} 
Example 100
Source File: SortSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{RandomDataGenerator, Row}



class SortSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  test("basic sorting using ExternalSort") {

    val input = Seq(
      ("Hello", 4, 2.0),
      ("Hello", 1, 1.0),
      ("World", 8, 3.0)
    )

    checkAnswer(
      input.toDF("a", "b", "c"),
      (child: SparkPlan) => Sort('a.asc :: 'b.asc :: Nil, global = true, child = child),
      input.sortBy(t => (t._1, t._2)).map(Row.fromTuple),
      sortAnswers = false)

    checkAnswer(
      input.toDF("a", "b", "c"),
      (child: SparkPlan) => Sort('b.asc :: 'a.asc :: Nil, global = true, child = child),
      input.sortBy(t => (t._2, t._1)).map(Row.fromTuple),
      sortAnswers = false)
  }

  test("sort followed by limit") {
    checkThatPlansAgree(
      (1 to 100).map(v => Tuple1(v)).toDF("a"),
      (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child = child)),
      (child: SparkPlan) => Limit(10, ReferenceSort('a.asc :: Nil, global = true, child)),
      sortAnswers = false
    )
  }

  test("sorting does not crash for large inputs") {
    val sortOrder = 'a.asc :: Nil
    val stringLength = 1024 * 1024 * 2
    checkThatPlansAgree(
      Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
      Sort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
      ReferenceSort(sortOrder, global = true, _: SparkPlan),
      sortAnswers = false
    )
  }

  test("sorting updates peak execution memory") {
    AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") {
      checkThatPlansAgree(
        (1 to 100).map(v => Tuple1(v)).toDF("a"),
        (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child = child),
        (child: SparkPlan) => ReferenceSort('a.asc :: Nil, global = true, child),
        sortAnswers = false)
    }
  }

  // Test sorting on different data types
  for (
    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
    nullable <- Seq(true, false);
    sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
    randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
  ) {
    test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
      val inputData = Seq.fill(1000)(randomDataGenerator())
      val inputDf = sqlContext.createDataFrame(
        sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
        StructType(StructField("a", dataType, nullable = true) :: Nil)
      )
      checkThatPlansAgree(
        inputDf,
        p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)),
        ReferenceSort(sortOrder, global = true, _: SparkPlan),
        sortAnswers = false
      )
    }
  }
} 
Example 101
Source File: GroupedIteratorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema)

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 102
Source File: ExtraStrategiesSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package test.org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.UTF8String

case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    sparkContext.parallelize(Seq(row))
  }

  override def children: Seq[SparkPlan] = Nil
}

object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}

class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      sqlContext.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      sqlContext.experimental.extraStrategies = Nil
    }
  }
} 
Example 103
Source File: PartitionedWriteSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("write many partitions") {
    val path = Utils.createTempDir()
    path.delete()

    val df = sqlContext.range(100).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      sqlContext.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("write many partitions with repeats") {
    val path = Utils.createTempDir()
    path.delete()

    val base = sqlContext.range(100)
    val df = base.unionAll(base).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      sqlContext.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("partitioned columns should appear at the end of schema") {
    withTempPath { f =>
      val path = f.getAbsolutePath
      Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path)
      assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
    }
  }
} 
Example 104
Source File: InfinispanRelation.scala    From infinispan-spark   with Apache License 2.0 5 votes vote down vote up
package org.infinispan.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}
import org.infinispan.client.hotrod.marshall.MarshallerUtil
import org.infinispan.spark.config.ConnectorConfiguration
import org.infinispan.spark.rdd.InfinispanRDD


class InfinispanRelation(context: SQLContext, val parameters: Map[String, String])
  extends BaseRelation with PrunedFilteredScan with Serializable {

   override def sqlContext: SQLContext = context

   lazy val props: ConnectorConfiguration = ConnectorConfiguration(parameters)

   val clazz = {
      val protoEntities = props.getProtoEntities
      val targetEntity = Option(props.getTargetEntity) match {
         case Some(p) => p
         case None => if (protoEntities.nonEmpty) protoEntities.head
         else
            throw new IllegalArgumentException(s"No target entity nor annotated protobuf entities found, check the configuration")
      }
      targetEntity
   }

   @transient lazy val mapper = ObjectMapper.forBean(schema, clazz)

   override def schema: StructType = SchemaProvider.fromJavaBean(clazz)

   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
      val rdd: InfinispanRDD[AnyRef, AnyRef] = new InfinispanRDD(context.sparkContext, props)
      val serCtx = MarshallerUtil.getSerializationContext(rdd.remoteCache.getRemoteCacheManager)
      val message = serCtx.getMarshaller(clazz).getTypeName
      val projections = toIckle(requiredColumns)
      val predicates = toIckle(filters)

      val select = if (projections.nonEmpty) s"SELECT $projections" else ""
      val from = s"FROM $message"
      val where = if (predicates.nonEmpty) s"WHERE $predicates" else ""

      val query = s"$select $from $where"

      rdd.filterByQuery[AnyRef](query.trim).values.map(mapper(_, requiredColumns))
   }

   def toIckle(columns: Array[String]): String = columns.mkString(",")

   def toIckle(filters: Array[Filter]): String = filters.map(ToIckle).mkString(" AND ")

   private def ToIckle(f: Filter): String = {
      f match {
         case StringEndsWith(a, v) => s"$a LIKE '%$v'"
         case StringContains(a, v) => s"$a LIKE '%$v%'"
         case StringStartsWith(a, v) => s"$a LIKE '$v%'"
         case EqualTo(a, v) => s"$a = '$v'"
         case GreaterThan(a, v) => s"$a > $v"
         case GreaterThanOrEqual(a, v) => s"$a >= $v"
         case LessThan(a, v) => s"$a < $v"
         case LessThanOrEqual(a, v) => s"$a <= $v"
         case IsNull(a) => s"$a is null"
         case IsNotNull(a) => s"$a is not null"
         case In(a, vs) => s"$a IN (${vs.map(v => s"'$v'").mkString(",")})"
         case Not(filter) => s"NOT ${ToIckle(filter)}"
         case And(leftFilter, rightFilter) => s"${ToIckle(leftFilter)} AND ${ToIckle(rightFilter)}"
         case Or(leftFilter, rightFilter) => s"${ToIckle(leftFilter)} OR ${ToIckle(rightFilter)}"
      }
   }
} 
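The filter translation is the core of this connector, so here is a minimal, self-contained sketch (not part of the connector) that renders a few of Spark's org.apache.spark.sql.sources.Filter objects into the same Ickle-style predicate string; the object name and the reduced set of cases are assumptions for illustration.

import org.apache.spark.sql.sources._

// Standalone sketch: turn a handful of Spark data source Filters into an
// Ickle-style WHERE clause, mirroring the private ToIckle method above.
object FilterToIckleSketch {

  def render(f: Filter): String = f match {
    case EqualTo(a, v)     => s"$a = '$v'"
    case GreaterThan(a, v) => s"$a > $v"
    case IsNotNull(a)      => s"$a is not null"
    case And(l, r)         => s"${render(l)} AND ${render(r)}"
    case other             => sys.error(s"unsupported filter in this sketch: $other")
  }

  def main(args: Array[String]): Unit = {
    val filters: Array[Filter] = Array(EqualTo("name", "infinispan"), GreaterThan("age", 21))
    // Prints: name = 'infinispan' AND age > 21
    println(filters.map(render).mkString(" AND "))
  }
}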
Example 105
Source File: ObjectMapper.scala    From infinispan-spark   with Apache License 2.0 5 votes vote down vote up
package org.infinispan.spark.sql

import java.beans.Introspector

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRowWithSchema}
import org.apache.spark.sql.types.StructType


object ObjectMapper {

   def forBean(schema: StructType, beanClass: Class[_]): (AnyRef, Array[String]) => Row = {
      val beanInfo = Introspector.getBeanInfo(beanClass)
      val attrs = schema.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
      val extractors = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
      val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
         (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
      }
      (from: Any, columns: Array[String]) => {
         if (columns.nonEmpty) {
            from match {
               case _: Array[_] => new GenericRowWithSchema(from.asInstanceOf[Array[Any]], schema)
               case f: Any =>
                  val rowSchema = StructType(Array(schema(columns.head)))
                  new GenericRowWithSchema(Array(f), rowSchema)
            }
         } else {
            new GenericRowWithSchema(methodsToConverts.map { case (e, convert) =>
               val invoke: AnyRef = e.invoke(from)
               convert(invoke)
            }, schema)

         }
      }
   }

} 
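As a quick illustration of the value forBean ultimately produces for each bean, the sketch below uses Spark's GenericRowWithSchema directly; the object name and sample values are made up for illustration.

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Minimal sketch: a schema-aware Row built by hand, which is the shape of value
// the mapper above returns for each bean.
object RowWithSchemaSketch {

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)))

    val row = new GenericRowWithSchema(Array[Any]("alice", 42), schema)

    // Because the schema travels with the row, fields can be read back by name.
    println(row.getAs[String]("name")) // alice
    println(row.getAs[Int]("age"))     // 42
  }
}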
Example 106
Source File: BEDRelation.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.datasources.BED

import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Encoders, Row, SQLContext, SparkSession}
import org.apache.spark.sql.sources.{BaseRelation, Filter, PrunedFilteredScan}
import org.biodatageeks.sequila.utils.{Columns, DataQualityFuncs}

class BEDRelation(path: String)(@transient val sqlContext: SQLContext)
  extends BaseRelation
    with PrunedFilteredScan
    with Serializable {

  @transient val logger = Logger.getLogger(this.getClass.getCanonicalName)
  override def schema: org.apache.spark.sql.types.StructType = Encoders.product[org.biodatageeks.formats.BrowserExtensibleData].schema

  private def getValueFromColumn(colName:String, r:Array[String]): Any = {
    colName match {
      case Columns.CONTIG       =>  DataQualityFuncs.cleanContig(r(0) )
      case Columns.START        =>  r(1).toInt + 1 //Convert interval to 1-based
      case Columns.END          =>  r(2).toInt
      case Columns.NAME         =>  if (r.length > 3) Some (r(3)) else None
      case Columns.SCORE        =>  if (r.length > 4) Some (r(4).toInt) else None
      case Columns.STRAND       =>  if (r.length > 5) Some (r(5)) else None
      case Columns.THICK_START  =>  if (r.length > 6) Some (r(6).toInt) else None
      case Columns.THICK_END    =>  if (r.length > 7) Some (r(7).toInt) else None
      case Columns.ITEM_RGB     =>  if (r.length > 8) Some (r(8).split(",").map(_.toInt)) else None
      case Columns.BLOCK_COUNT  =>  if (r.length > 9) Some (r(9).toInt) else None
      case Columns.BLOCK_SIZES  =>  if (r.length > 10) Some (r(10).split(",").map(_.toInt)) else None
      case Columns.BLOCK_STARTS =>  if (r.length > 11) Some (r(11).split(",").map(_.toInt)) else None
      case _                    =>  throw new Exception(s"Unknown column found: ${colName}")
    }
  }
  override def buildScan(requiredColumns:Array[String], filters:Array[Filter]): RDD[Row] = {
    sqlContext
      .sparkContext
      .textFile(path)
      .filter(!_.toLowerCase.startsWith("track"))
      .filter(!_.toLowerCase.startsWith("browser"))
      .map(_.split("\t"))
      .map { r =>
        val record = new Array[Any](requiredColumns.length)
        for (i <- requiredColumns.indices) {
          record(i) = getValueFromColumn(requiredColumns(i), r)
        }
        Row.fromSeq(record)
      }
  }

} 
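To query this relation through spark.read, Spark also needs a RelationProvider in the data source package; the project ships its own registration, but a minimal, hypothetical provider would look like the sketch below (class name, wiring and format string are assumptions, shown only to make the BaseRelation/PrunedFilteredScan contract concrete).

package org.biodatageeks.sequila.datasources.BED

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider}

// Hypothetical provider: Spark resolves the format name to this class and asks
// it to create the relation, passing the options map supplied by the reader.
class DefaultSource extends RelationProvider {
  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String]): BaseRelation =
    new BEDRelation(parameters("path"))(sqlContext)
}

// Usage (format string is an assumption):
//   spark.read.format("org.biodatageeks.sequila.datasources.BED").load("/path/to/file.bed")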
Example 107
Source File: VCFRelation.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.datasources.VCF

import io.projectglow.Glow
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
import org.apache.spark.sql.sources._
import org.biodatageeks.sequila.utils.{Columns, DataQualityFuncs}
import org.apache.spark.sql.functions._




class VCFRelation(path: String,
                  normalization_mode: Option[String] = None,
                  ref_genome_path : Option[String] = None )(@transient val sqlContext: SQLContext) extends BaseRelation
  with PrunedScan
  with Serializable
  with Logging {

  val spark: SparkSession = sqlContext.sparkSession

  val cleanContigUDF = udf[String, String](DataQualityFuncs.cleanContig)

  lazy val inputDf: DataFrame = spark
    .read
    .format("vcf")
    .option("splitToBiallelic", "true")
    .load(path)
  lazy val dfNormalized = {
    normalization_mode match {
      case Some(m) =>
        if ((m.equalsIgnoreCase("normalize") || m.equalsIgnoreCase("split_and_normalize"))
          && ref_genome_path.isEmpty) {
          throw new Exception("Variant normalization mode specified but ref_genome_path is empty")
        }
        Glow.transform(m.toLowerCase(), inputDf, Map("reference_genome_path" -> ref_genome_path.get))
      case _ => inputDf
    }
  }.withColumnRenamed("contigName", Columns.CONTIG)
    .withColumnRenamed("start", Columns.START)
    .withColumnRenamed("end", Columns.END)
    .withColumnRenamed("referenceAllele", Columns.REF)
    .withColumnRenamed("alternateAlleles", Columns.ALT)

  lazy val df = dfNormalized
    .withColumn(Columns.CONTIG, cleanContigUDF(dfNormalized(Columns.CONTIG)))

  override def schema: org.apache.spark.sql.types.StructType = {
   df.schema
  }

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val projected =
      if (requiredColumns.nonEmpty) df.select(requiredColumns.head, requiredColumns.tail: _*)
      else df
    projected.rdd
  }

} 
Example 108
Source File: PileupTestBase.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.tests.pileup

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{IntegerType, ShortType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfter, FunSuite}

class PileupTestBase extends FunSuite
  with DataFrameSuiteBase
  with BeforeAndAfter
  with SharedSparkContext{

  val sampleId = "NA12878.multichrom.md"
  val samResPath: String = getClass.getResource("/multichrom/mdbam/samtools.pileup").getPath
  val referencePath: String = getClass.getResource("/reference/Homo_sapiens_assembly18_chr1_chrM.small.fasta").getPath
  val bamPath: String = getClass.getResource(s"/multichrom/mdbam/${sampleId}.bam").getPath
  val cramPath : String = getClass.getResource(s"/multichrom/mdcram/${sampleId}.cram").getPath
  val tableName = "reads_bam"
  val tableNameCRAM = "reads_cram"

  val schema: StructType = StructType(
    List(
      StructField("contig", StringType, nullable = true),
      StructField("position", IntegerType, nullable = true),
      StructField("reference", StringType, nullable = true),
      StructField("coverage", ShortType, nullable = true),
      StructField("pileup", StringType, nullable = true),
      StructField("quality", StringType, nullable = true)
    )
  )
  before {
    System.setProperty("spark.kryo.registrator", "org.biodatageeks.sequila.pileup.serializers.CustomKryoRegistrator")
    spark
      .conf.set("spark.sql.shuffle.partitions",1) //FIXME: In order to get orderBy in Samtools tests working - related to exchange partitions stage
    spark.sql(s"DROP TABLE IF EXISTS $tableName")
    spark.sql(
      s"""
         |CREATE TABLE $tableName
         |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")
         |
      """.stripMargin)

    spark.sql(s"DROP TABLE IF EXISTS $tableNameCRAM")
    spark.sql(
      s"""
         |CREATE TABLE $tableNameCRAM
         |USING org.biodatageeks.sequila.datasources.BAM.CRAMDataSource
         |OPTIONS(path "$cramPath", refPath "$referencePath" )
         |
      """.stripMargin)

    val mapToString = (map: Map[Byte, Short]) => {
      if (map == null)
        "null"
      else
        map.map({
          case (k, v) => k.toChar -> v}).mkString.replace(" -> ", ":")
    }

    val byteToString = ((byte: Byte) => byte.toString)

    spark.udf.register("mapToString", mapToString)
    spark.udf.register("byteToString", byteToString)
  }

} 
Example 109
Source File: Writer.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.tests.pileup

import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}

object Writer {

  val mapToString = (map: Map[Byte, Short]) => {
    if (map == null)
      "null"
    else
      map.map({
        case (k, v) => k.toChar -> v
      }).toSeq.sortBy(_._1).mkString.replace(" -> ", ":")
  }

  def saveToFile(spark: SparkSession, res: Dataset[Row], path: String) = {
    spark.udf.register("mapToString", mapToString)
    res
      .selectExpr("contig", "pos_start", "pos_end", "ref", "cast(coverage as int)", "mapToString(alts)")
      .coalesce(1)
      .write
      .mode(SaveMode.Overwrite)
      .csv(path)
  }
} 
Example 110
Source File: JoinOrderTestSuite.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.tests.rangejoins

import java.io.{OutputStreamWriter, PrintWriter}

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{
  IntegerType,
  StringType,
  StructField,
  StructType
}
import org.bdgenomics.utils.instrumentation.{
  Metrics,
  MetricsListener,
  RecordedMetrics
}
import org.biodatageeks.sequila.rangejoins.IntervalTree.IntervalTreeJoinStrategyOptim
import org.scalatest.{BeforeAndAfter, FunSuite}

class JoinOrderTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  val schema = StructType(
    Seq(StructField("chr", StringType),
        StructField("start", IntegerType),
        StructField("end", IntegerType)))
  val metricsListener = new MetricsListener(new RecordedMetrics())
  val writer = new PrintWriter(new OutputStreamWriter(System.out))
  before {
    System.setSecurityManager(null)
    spark.experimental.extraStrategies = new IntervalTreeJoinStrategyOptim(
      spark) :: Nil
    Metrics.initialize(sc)
    val rdd1 = sc
      .textFile(getClass.getResource("/refFlat.txt.bz2").getPath)
      .map(r => r.split('\t'))
      .map(
        r =>
          Row(
            r(2).toString,
            r(4).toInt,
            r(5).toInt
        ))
    val ref = spark.createDataFrame(rdd1, schema)
    ref.createOrReplaceTempView("ref")

    val rdd2 = sc
      .textFile(getClass.getResource("/snp150Flagged.txt.bz2").getPath)
      .map(r => r.split('\t'))
      .map(
        r =>
          Row(
            r(1).toString,
            r(2).toInt,
            r(3).toInt
        ))
    val snp = spark
      .createDataFrame(rdd2, schema)
    snp.createOrReplaceTempView("snp")
  }

  test("Join order - broadcasting snp table") {
    spark.sqlContext.setConf("spark.biodatageeks.rangejoin.useJoinOrder",
                             "true")
    val query =
      s"""
         |SELECT snp.*,ref.* FROM ref JOIN snp
         |ON (ref.chr=snp.chr AND snp.end>=ref.start AND snp.start<=ref.end)
       """.stripMargin

    assert(spark.sql(query).count === 616404L)

  }

  test("Join order - broadcasting ref table") {
    spark.sqlContext.setConf("spark.biodatageeks.rangejoin.useJoinOrder",
                             "true")
    val query =
      s"""
         |SELECT snp.*,ref.* FROM snp JOIN ref
         |ON (ref.chr=snp.chr AND snp.end>=ref.start AND snp.start<=ref.end)
       """.stripMargin
    assert(spark.sql(query).count === 616404L)

  }
  after {
    Metrics.print(writer, Some(metricsListener.metrics.sparkMetrics.stageTimes))
    writer.flush()
    Metrics.stopRecording()
  }
} 
Example 111
Source File: A_1_DataFrameTest.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.sql

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}


object A_1_DataFrameTest {

  def main(args: Array[String]): Unit = {
    // Entry point and SparkSession setup assumed for this truncated example
    val sparkSession = SparkSession.builder()
      .appName("A_1_DataFrameTest")
      .master("local")
      .getOrCreate()

    // Define a schema from a space-separated list of field names
    val schemaString = "name age"
    val fields = schemaString.split(" ")
      .map(fieldName => StructField(fieldName, StringType, nullable = true))
    val structType = StructType(fields)

    val personRDD = sparkSession.sparkContext.textFile("src/main/resources/sparkresource/people.txt")
      .map(_.split(","))
      // Convert each RDD element to a Row
      .map(attr => Row(attr(0), attr(1).trim))
    // Apply the schema to the RDD and create a DataFrame
    sparkSession.createDataFrame(personRDD, structType).createOrReplaceTempView("people1")
    val dataFrameBySchema = sparkSession.sql("select name,age from people1 where age > 19")
    dataFrameBySchema.show()
  }

} 
Example 112
Source File: A_8_MyAverage.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.sql

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


object A_8_MyAverage extends UserDefinedAggregateFunction{
  override def inputSchema: StructType = StructType(StructField("inputColumn",LongType)::Nil)

  override def bufferSchema: StructType = {
    StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)
  }

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)){
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1

    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder().appName("A_8_MyAverage")
      .master("local")
      .getOrCreate()
    sparkSession.udf.register("A_8_MyAverage",A_8_MyAverage)

    val dataFrame = sparkSession.read.json("src/main/resources/sparkresource/employees.json")
    dataFrame.createOrReplaceTempView("employees")

    val result = sparkSession.sql("select A_8_MyAverage(salary) as average_salary from employees")
    result.show()
  }
} 
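The same aggregate can also be invoked through the DataFrame API instead of SQL, since a UserDefinedAggregateFunction is callable on Columns. A minimal sketch, assuming the employees.json file from the example above:

import com.sev7e0.wow.sql.A_8_MyAverage
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object A_8_MyAverageDataFrameUsage {

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder()
      .appName("A_8_MyAverageDataFrameUsage")
      .master("local")
      .getOrCreate()

    val employees = sparkSession.read.json("src/main/resources/sparkresource/employees.json")

    // A UserDefinedAggregateFunction can be applied to a Column directly.
    employees.agg(A_8_MyAverage(col("salary")).as("average_salary")).show()

    sparkSession.stop()
  }
}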
Example 113
Source File: SparkAvroDecoder.scala    From cloudflow   with Apache License 2.0 5 votes vote down vote up
package cloudflow.spark.avro

import org.apache.log4j.Logger

import java.io.ByteArrayOutputStream

import scala.reflect.runtime.universe._

import org.apache.avro.generic.{ GenericDatumReader, GenericDatumWriter, GenericRecord }
import org.apache.avro.io.{ DecoderFactory, EncoderFactory }
import org.apache.spark.sql.{ Dataset, Encoder, Row }
import org.apache.spark.sql.catalyst.encoders.{ encoderFor, ExpressionEncoder, RowEncoder }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.StructType
import org.apache.avro.Schema

import cloudflow.spark.sql.SQLImplicits._

case class EncodedKV(key: String, value: Array[Byte])

case class SparkAvroDecoder[T: Encoder: TypeTag](avroSchema: String) {

  val encoder: Encoder[T]                           = implicitly[Encoder[T]]
  val sqlSchema: StructType                         = encoder.schema
  val encoderForDataColumns: ExpressionEncoder[Row] = RowEncoder(sqlSchema)
  @transient lazy val _avroSchema                   = new Schema.Parser().parse(avroSchema)
  @transient lazy val rowConverter                  = SchemaConverters.createConverterToSQL(_avroSchema, sqlSchema)
  @transient lazy val datumReader                   = new GenericDatumReader[GenericRecord](_avroSchema)
  @transient lazy val decoder                       = DecoderFactory.get
  def decode(bytes: Array[Byte]): Row = {
    val binaryDecoder = decoder.binaryDecoder(bytes, null)
    val record        = datumReader.read(null, binaryDecoder)
    rowConverter(record).asInstanceOf[GenericRow]
  }

}


case class SparkAvroEncoder[T: Encoder: TypeTag](avroSchema: String) {

  @transient lazy val log = Logger.getLogger(getClass.getName)

  val BufferSize = 5 * 1024 // 5 Kb

  val encoder                     = implicitly[Encoder[T]]
  val sqlSchema                   = encoder.schema
  @transient lazy val _avroSchema = new Schema.Parser().parse(avroSchema)

  val recordName                = "topLevelRecord" // ???
  val recordNamespace           = "recordNamespace" // ???
  @transient lazy val converter = AvroConverter.createConverterToAvro(sqlSchema, recordName, recordNamespace)

  // Risk: This process is memory intensive. Might require thread-level buffers to optimize memory usage
  def rowToBytes(row: Row): Array[Byte] = {
    val genRecord = converter(row).asInstanceOf[GenericRecord]
    if (log.isDebugEnabled) log.debug(s"genRecord = $genRecord")
    val datumWriter   = new GenericDatumWriter[GenericRecord](_avroSchema)
    val avroEncoder   = EncoderFactory.get
    val byteArrOS     = new ByteArrayOutputStream(BufferSize)
    val binaryEncoder = avroEncoder.binaryEncoder(byteArrOS, null)
    datumWriter.write(genRecord, binaryEncoder)
    binaryEncoder.flush()
    byteArrOS.toByteArray
  }

  def encode(dataset: Dataset[T]): Dataset[Array[Byte]] =
    dataset.toDF().mapPartitions(rows ⇒ rows.map(rowToBytes)).as[Array[Byte]]

  // Note to self: I'm not sure how heavy this chain of transformations is
  def encodeWithKey(dataset: Dataset[T], keyFun: T ⇒ String): Dataset[EncodedKV] = {
    val encoder             = encoderFor[T]
    implicit val rowEncoder = RowEncoder(encoder.schema).resolveAndBind()
    dataset.map { value ⇒
      val key         = keyFun(value)
      val internalRow = encoder.toRow(value)
      val row         = rowEncoder.fromRow(internalRow)
      val bytes       = rowToBytes(row)
      EncodedKV(key, bytes)
    }
  }

} 
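The decoder and encoder above are thin wrappers around Avro's binary encoding. The standalone sketch below (schema and record values are made up for illustration) shows the underlying GenericDatumWriter/GenericDatumReader round trip without any Spark involved.

import java.io.ByteArrayOutputStream

import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

object AvroRoundTripSketch {

  val schemaJson: String =
    """{"type":"record","name":"Simple","fields":[{"name":"name","type":"string"}]}"""

  def main(args: Array[String]): Unit = {
    val schema = new Schema.Parser().parse(schemaJson)

    val record = new GenericData.Record(schema)
    record.put("name", "sphere")

    // Encode: GenericDatumWriter + binary encoder, as in rowToBytes above.
    val out = new ByteArrayOutputStream()
    val writer = new GenericDatumWriter[GenericRecord](schema)
    val encoder = EncoderFactory.get.binaryEncoder(out, null)
    writer.write(record, encoder)
    encoder.flush()

    // Decode: GenericDatumReader + binary decoder, as in decode above.
    val reader = new GenericDatumReader[GenericRecord](schema)
    val decoder = DecoderFactory.get.binaryDecoder(out.toByteArray, null)
    val decoded = reader.read(null, decoder)
    println(decoded.get("name")) // sphere
  }
}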
Example 114
Source File: SparkAvroDecoderSuite.scala    From cloudflow   with Apache License 2.0 5 votes vote down vote up
package cloudflow.spark.avro

import org.apache.spark.sql.Row
import org.scalatest.{ Matchers, WordSpec }
import cloudflow.streamlets.avro.AvroCodec
import cloudflow.streamlets.Codec

import cloudflow.spark.sql.SQLImplicits._

class SparkAvroDecoderSuite extends WordSpec with Matchers {

  val simpleCodec: Codec[Simple]   = new AvroCodec(Simple.SCHEMA$)
  val complexCodec: Codec[Complex] = new AvroCodec(Complex.SCHEMA$)

  "SparkAvroDecoder" should {
    "decode a simple case class" in {

      val sample  = Simple("sphere")
      val encoded = simpleCodec.encode(sample)

      val decoder = new SparkAvroDecoder[Simple](Simple.SCHEMA$.toString)
      val result  = decoder.decode(encoded)
      result.getAs[String](0) should be(sample.name)
    }

    "decode a complex case class" in {

      val complex = Complex(Simple("room"), 101, 0.01, 0.001f)
      val encoded = complexCodec.encode(complex)

      val decoder = new SparkAvroDecoder[Complex](Complex.SCHEMA$.toString)
      val result  = decoder.decode(encoded)
      result.getAs[Row](0).getAs[String](0) should be(complex.simple.name)
      result.getAs[Int](1) should be(complex.count)
      result.getAs[Double](2) should be(complex.range)
      result.getAs[Float](3) should be(complex.error)
    }

  }

} 
Example 115
Source File: PushDownJdbcRDD.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.jdbc.utilities

import java.sql.{Connection, ResultSet}

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.JdbcRDD
import org.apache.spark.sql.Row

import com.paypal.gimel.common.utilities.GenericUtils
import com.paypal.gimel.logger.Logger


class PushDownJdbcRDD(sc: SparkContext,
                      getConnection: () => Connection,
                      sql: String,
                      mapRow: ResultSet => Row = PushDownJdbcRDD.resultSetToRow)
  extends JdbcRDD[Row](sc, getConnection, sql, 0, 100, 1, mapRow)
    with Logging {

  override def compute(thePart: Partition,
                       context: TaskContext): Iterator[Row] = {
    val logger = Logger(this.getClass.getName)
    val functionName = s"[QueryHash: ${sql.hashCode}]"
    logger.info(s"Proceeding to execute push down query $functionName: $sql")
    val queryResult: String = GenericUtils.time(functionName, Some(logger)) {
      JDBCConnectionUtility.withResources(getConnection()) { connection =>
        JdbcAuxiliaryUtilities.executeQueryAndReturnResultString(
          sql,
          connection
        )
      }
    }
    Seq(Row(queryResult)).iterator
  }
}

object PushDownJdbcRDD {
  def resultSetToRow(rs: ResultSet): Row = {
    // JDBC ResultSet columns are 1-indexed; index 0 would throw an SQLException.
    Row(rs.getString(1))
  }
} 
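A hedged usage sketch, assuming the class above is on the classpath and a reachable JDBC endpoint (the URL, credentials and query below are placeholders): the overridden compute runs the whole query on the database and hands back its result as a single Row.

import java.sql.DriverManager

import org.apache.spark.{SparkConf, SparkContext}

import com.paypal.gimel.jdbc.utilities.PushDownJdbcRDD

object PushDownJdbcRDDUsage {

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("pushdown-usage").setMaster("local[1]"))

    val jdbcUrl = "jdbc:postgresql://example.invalid:5432/db" // placeholder
    val rdd = new PushDownJdbcRDD(
      sc,
      () => DriverManager.getConnection(jdbcUrl, "user", "password"), // placeholder credentials
      "SELECT count(*) FROM some_table")                              // placeholder query

    rdd.collect().foreach(println)
    sc.stop()
  }
}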
Example 116
Source File: DataStream.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.kafka2

import scala.language.implicitConversions

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.streaming.DataStreamWriter

import com.paypal.gimel.datastreamfactory.{GimelDataStream2, StructuredStreamingResult}
import com.paypal.gimel.kafka2.conf.KafkaClientConfiguration
import com.paypal.gimel.kafka2.reader.KafkaStreamConsumer
import com.paypal.gimel.kafka2.writer.KafkaStreamProducer
import com.paypal.gimel.logger.Logger

class DataStream(sparkSession: SparkSession) extends GimelDataStream2(sparkSession: SparkSession) {

  // GET LOGGER
  val logger = Logger()
  logger.info(s"Initiated --> ${this.getClass.getName}")

  
private class DataStreamException(message: String, cause: Throwable)
  extends RuntimeException(message) {
  if (cause != null) {
    initCause(cause)
  }

  def this(message: String) = this(message, null)
} 
Example 117
Source File: KafkaStreamProducer.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.kafka2.writer

import java.util.Properties

import scala.collection.JavaConverters._
import scala.collection.immutable.Map
import scala.language.implicitConversions

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.streaming.DataStreamWriter

import com.paypal.gimel.common.conf.GimelConstants
import com.paypal.gimel.kafka2.conf.{KafkaClientConfiguration, KafkaConstants}
import com.paypal.gimel.kafka2.utilities.{KafkaOptionsLoaderUtils, KafkaUtilitiesException}

object KafkaStreamProducer {
  val logger = com.paypal.gimel.logger.Logger()

  
  def produceStreamToKafka(conf: KafkaClientConfiguration, dataFrame: DataFrame): DataStreamWriter[Row] = {
    def MethodName: String = new Exception().getStackTrace().apply(1).getMethodName()
    logger.info(" @Begin --> " + MethodName)

    val kafkaProps: Properties = conf.kafkaProducerProps
    logger.info(s"Kafka Props for Producer -> ${kafkaProps.asScala.mkString("\n")}")
    logger.info("Begin Publishing to Kafka....")
    // Retrieve kafka options from OptionsLoader if specified
    val kafkaTopicsOptionsMap : Map[String, Map[String, String]] = KafkaOptionsLoaderUtils.getAllKafkaTopicsOptions(conf)
    logger.info("kafkaTopicsOptionsMap -> " + kafkaTopicsOptionsMap)
    try {
      val eachKafkaTopicToOptionsMap = KafkaOptionsLoaderUtils.getEachKafkaTopicToOptionsMap(kafkaTopicsOptionsMap)
      val kafkaTopicOptions = eachKafkaTopicToOptionsMap.get(conf.kafkaTopics)
      kafkaTopicOptions match {
        case None =>
          throw new IllegalStateException(s"""Could not load options for the kafka topic -> ${conf.kafkaTopics}""")
        case Some(kafkaOptions) =>
          dataFrame
            .writeStream
            .format(KafkaConstants.KAFKA_FORMAT)
            .option(KafkaConstants.KAFKA_TOPIC, conf.kafkaTopics)
            .option(GimelConstants.STREAMING_CHECKPOINT_LOCATION, conf.streamingCheckpointLocation)
            .outputMode(conf.streamingOutputMode)
            .options(kafkaOptions)
      }
    }
    catch {
      case ex: Throwable => {
        ex.printStackTrace()
        val msg =
          s"""
             |kafkaTopic -> ${conf.kafkaTopics}
             |kafkaParams --> ${kafkaProps.asScala.mkString("\n")}}
          """.stripMargin
        throw new KafkaUtilitiesException(s"Failed While Pushing Data Into Kafka \n ${msg}")
      }
    }
  }
} 
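produceStreamToKafka only configures a DataStreamWriter; the caller still has to start the query and wait on it. The standalone sketch below shows the same configure-start-await pattern with Spark's built-in Kafka sink instead of Gimel's configuration classes (it assumes the spark-sql-kafka package is on the classpath; broker, topic and checkpoint path are placeholders).

import org.apache.spark.sql.SparkSession

object KafkaStreamWriterSketch {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("kafka-writer-sketch")
      .master("local[2]")
      .getOrCreate()

    // Toy streaming source; real code would obtain its DataFrame upstream.
    val streamingDf = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "1")
      .load()
      .selectExpr("CAST(value AS STRING) AS value") // the Kafka sink expects a 'value' column

    val writer = streamingDf.writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092")          // placeholder
      .option("topic", "example-topic")                             // placeholder
      .option("checkpointLocation", "/tmp/kafka-sketch-checkpoint") // placeholder

    // As with produceStreamToKafka, configuring the writer does nothing by itself.
    val query = writer.start()
    query.awaitTermination()
  }
}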
Example 118
Source File: RestApiConsumer.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.restapi.reader

import scala.language.implicitConversions

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import com.paypal.gimel.common.gimelservices.GimelServiceUtilities
import com.paypal.gimel.logger.Logger
import com.paypal.gimel.restapi.conf.RestApiClientConfiguration



object RestApiConsumer {

  val logger: Logger = Logger()
  val utils: GimelServiceUtilities = GimelServiceUtilities()

  def consume(sparkSession: SparkSession, conf: RestApiClientConfiguration): DataFrame = {
    def MethodName: String = new Exception().getStackTrace().apply(1).getMethodName()
    logger.info(" @Begin --> " + MethodName)

    val responsePayload = conf.httpsFlag match {
      case false => utils.get(conf.resolvedUrl.toString)
      case true => utils.httpsGet(conf.resolvedUrl.toString)
    }
    conf.parsePayloadFlag match {
      case false =>
        logger.info("NOT Parsing payload.")
        val rdd: RDD[String] = sparkSession.sparkContext.parallelize(Seq(responsePayload))
        val rowRdd: RDD[Row] = rdd.map(Row(_))
        val field: StructType = StructType(Seq(StructField(conf.payloadFieldName, StringType)))
        sparkSession.sqlContext.createDataFrame(rowRdd, field)
      case true =>
        logger.info("Parsing payload to fields - as requested.")
        val rdd: RDD[String] = sparkSession.sparkContext.parallelize(Seq(responsePayload))
        sparkSession.sqlContext.read.json(rdd)
    }
  }

} 
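The two branches above boil down to "wrap the payload as a single string column" versus "let Spark parse it as JSON". A minimal standalone sketch of both, with a made-up payload:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object RestPayloadSketch {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("payload-sketch").master("local").getOrCreate()
    val payload = """{"id": 1, "name": "alice"}""" // stand-in for an HTTP response body

    // parsePayloadFlag == false: keep the whole payload in a single string column.
    val rowRdd: RDD[Row] = spark.sparkContext.parallelize(Seq(Row(payload)))
    val rawDf = spark.createDataFrame(rowRdd, StructType(Seq(StructField("payload", StringType))))
    rawDf.show(truncate = false)

    // parsePayloadFlag == true: let Spark infer columns from the JSON itself.
    import spark.implicits._
    spark.read.json(Seq(payload).toDS()).show()

    spark.stop()
  }
}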
Example 119
Source File: UpdateCarbonTableTestCaseWithBadRecord.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.iud

import org.apache.spark.sql.{Row, SaveMode}
import org.scalatest.BeforeAndAfterAll
import org.apache.carbondata.common.constants.LoggerAction
import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties
import org.apache.spark.sql.test.util.QueryTest

class UpdateCarbonTableTestCaseWithBadRecord extends QueryTest with BeforeAndAfterAll {
  override def beforeAll {

    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_BAD_RECORDS_ACTION , LoggerAction.FORCE.name())
  }


  test("test update operation with Badrecords action as force.") {
    sql("""drop table if exists badtable""").show
    sql("""create table badtable (c1 string,c2 int,c3 string,c5 string) STORED AS carbondata""")
    sql(s"""LOAD DATA LOCAL INPATH '$resourcesPath/IUD/badrecord.csv' INTO table badtable""")
    sql("""update badtable d  set (d.c2) = (d.c2 / 1)""").show()
    checkAnswer(
      sql("""select c1,c2,c3,c5 from badtable"""),
      Seq(Row("ravi",null,"kiran","huawei"),Row("manohar",null,"vanam","huawei"))
    )
    sql("""drop table badtable""").show


  }
  test("test update operation with Badrecords action as FAIL.") {
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_BAD_RECORDS_ACTION , LoggerAction.FAIL.name())
    sql("""drop table if exists badtable""").show
    sql("""create table badtable (c1 string,c2 int,c3 string,c5 string) STORED AS carbondata""")
    sql(s"""LOAD DATA LOCAL INPATH '$resourcesPath/IUD/badrecord.csv' INTO table badtable""")
    val exec = intercept[Exception] {
      sql("""update badtable d  set (d.c2) = (d.c2 / 1)""").show()
    }
    checkAnswer(
      sql("""select c1,c2,c3,c5 from badtable"""),
      Seq(Row("ravi",2,"kiran","huawei"),Row("manohar",4,"vanam","huawei"))
    )
    sql("""drop table badtable""").show

  }

  override def afterAll {
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_BAD_RECORDS_ACTION , LoggerAction.FORCE.name())
  }
} 
Example 120
Source File: TestUpdateAndDeleteWithLargeData.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.iud

import java.text.SimpleDateFormat

import org.apache.spark.sql.test.util.QueryTest
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.scalatest.BeforeAndAfterAll

import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties

class TestUpdateAndDeleteWithLargeData extends QueryTest with BeforeAndAfterAll {
  var df: DataFrame = _

  override def beforeAll {
    dropTable()
    buildTestData()
  }

  private def buildTestData(): Unit = {

    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT, "yyyy-MM-dd")

    // Simulate data and write to table orders
    import sqlContext.implicits._

    val sdf = new SimpleDateFormat("yyyy-MM-dd")
    df = sqlContext.sparkSession.sparkContext.parallelize(1 to 1500000)
      .map(value => (value, new java.sql.Date(sdf.parse("2015-07-" + (value % 10 + 10)).getTime),
        "china", "aaa" + value, "phone" + 555 * value, "ASD" + (60000 + value), 14999 + value,
        "ordersTable" + value))
      .toDF("o_id", "o_date", "o_country", "o_name",
        "o_phonetype", "o_serialname", "o_salary", "o_comment")
    createTable()

  }

  private def createTable(): Unit = {
    df.write
      .format("carbondata")
      .option("tableName", "orders")
      .option("tempCSV", "true")
      .option("compress", "true")
      .mode(SaveMode.Overwrite)
      .save()
  }

  private def dropTable() = {
    sql("DROP TABLE IF EXISTS orders")

  }

  test("test the update and delete delete functionality for large data") {

    sql(
      """
            update ORDERS set (o_comment) = ('yyy')""").show()
    checkAnswer(sql(
      """select o_comment from orders limit 2 """), Seq(Row("yyy"), Row("yyy")))

    sql("delete from orders where exists (select 1 from orders)")

    checkAnswer(sql(
      """
           SELECT count(*) FROM orders
           """), Row(0))
  }

} 
Example 121
Source File: IntegerDataTypeTestCase.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.sortexpr

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll


class IntegerDataTypeTestCase extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
    sql("CREATE TABLE inttypetablesort (empno int, workgroupcategory string, deptno int, projectcode int,attendance int) STORED AS carbondata")
    sql(s"""LOAD DATA local inpath '$resourcesPath/data.csv' INTO TABLE inttypetablesort OPTIONS('DELIMITER'= ',', 'QUOTECHAR'= '\"')""")
  }

  test("select empno from inttypetablesort") {
    checkAnswer(
      sql("select empno from inttypetablesort"),
      Seq(Row(11), Row(12), Row(13), Row(14), Row(15), Row(16), Row(17), Row(18), Row(19), Row(20)))
  }

  override def afterAll {
    sql("drop table inttypetablesort")
  }
} 
Example 122
Source File: CacheRefreshTestCase.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.cloud

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.hive.CarbonHiveIndexMetadataUtil
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll

class CacheRefreshTestCase extends QueryTest with BeforeAndAfterAll {

  override protected def beforeAll(): Unit = {
    sql("drop database if exists cachedb cascade")
    sql("create database cachedb")
    sql("use cachedb")
  }

  override protected def afterAll(): Unit = {
    sql("use default")
    sql("drop database if exists cachedb cascade")
  }

  test("test cache refresh") {
    sql("create table tbl_cache1(col1 string, col2 int, col3 int) using carbondata")
    sql("insert into tbl_cache1 select 'a', 123, 345")
    CarbonHiveIndexMetadataUtil.invalidateAndDropTable(
      "cachedb", "tbl_cache1", sqlContext.sparkSession)
    // discard cached table info in cachedDataSourceTables
    val tableIdentifier = TableIdentifier("tbl_cache1", Option("cachedb"))
    sqlContext.sparkSession.sessionState.catalog.refreshTable(tableIdentifier)
    sql("create table tbl_cache1(col1 string, col2 int, col3 int) using carbondata")
    sql("delete from tbl_cache1")
    sql("insert into tbl_cache1 select 'b', 123, 345")
    checkAnswer(sql("select * from tbl_cache1"),
      Seq(Row("b", 123, 345)))
  }
} 
Example 123
Source File: TestDescribeTable.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.describeTable

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll

import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties


class TestDescribeTable extends QueryTest with BeforeAndAfterAll {

  override def beforeAll: Unit = {
    sql("DROP TABLE IF EXISTS Desc1")
    sql("DROP TABLE IF EXISTS Desc2")
    sql("drop table if exists a")
    sql("CREATE TABLE Desc1(Dec1Col1 String, Dec1Col2 String, Dec1Col3 int, Dec1Col4 double) STORED AS carbondata")
    sql("DESC Desc1")
    sql("DROP TABLE Desc1")
    sql("CREATE TABLE Desc1(Dec2Col1 BigInt, Dec2Col2 String, Dec2Col3 Bigint, Dec2Col4 Decimal) STORED AS carbondata")
    sql("CREATE TABLE Desc2(Dec2Col1 BigInt, Dec2Col2 String, Dec2Col3 Bigint, Dec2Col4 Decimal) STORED AS carbondata")
  }

  test("test describe table") {
    checkAnswer(sql("DESC Desc1"), Seq(Row("dec2col1","bigint",null),
      Row("dec2col2","string",null),
      Row("dec2col3","bigint",null),
      Row("dec2col4","decimal(10,0)",null)))
  }

  test("test describe formatted table") {
    checkExistence(sql("DESC FORMATTED Desc1"), true, "Table Block Size")
  }

  test("test describe formatted for partition table") {
    sql("create table a(a string) partitioned by (b int) STORED AS carbondata")
    sql("insert into a values('a',1)")
    sql("insert into a values('a',2)")
    val desc = sql("describe formatted a").collect()
    assert(desc(desc.indexWhere(_.get(0).toString.contains("#Partition")) + 2).get(0).toString.contains("b"))
    val descPar = sql("describe formatted a partition(b=1)").collect
    descPar.find(_.get(0).toString.contains("Partition Value:")) match {
      case Some(row) => assert(row.get(1).toString.contains("1"))
      case None => fail("Partition Value not found in describe formatted")
    }
    descPar.find(_.get(0).toString.contains("Location:")) match {
      case Some(row) => assert(row.get(1).toString.contains("target/warehouse/a/b=1"))
      case None => fail("Partition Location not found in describe formatted")
    }
    assert(descPar.exists(_.toString().contains("Partition Parameters:")))
  }

  override def afterAll: Unit = {
    sql("DROP TABLE Desc1")
    sql("DROP TABLE Desc2")
    sql("drop table if exists a")
    sql("drop table if exists b")
    CarbonProperties.getInstance().addProperty(CarbonCommonConstants.COMPRESSOR,
      CarbonCommonConstants.DEFAULT_COMPRESSOR)
  }

} 
Example 124
Source File: DateDataTypeNullDataTest.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.directdictionary

import java.sql.Date

import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.scalatest.BeforeAndAfterAll
import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties
import org.apache.spark.sql.test.util.QueryTest


class DateDataTypeNullDataTest extends QueryTest with BeforeAndAfterAll {
  var hiveContext: HiveContext = _

  override def beforeAll {
    try {
      sql(
        """CREATE TABLE IF NOT EXISTS timestampTyeNullData
                     (ID Int, dateField date, country String,
                     name String, phonetype String, serialname String, salary Int)
                    STORED AS carbondata"""
      )

      CarbonProperties.getInstance()
        .addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT, "yyyy/MM/dd")
      val csvFilePath = s"$resourcesPath/datasamplenull.csv"
      sql("LOAD DATA LOCAL INPATH '" + csvFilePath + "' INTO TABLE timestampTyeNullData").collect();

    } catch {
      case x: Throwable =>
        x.printStackTrace()
        CarbonProperties.getInstance()
        .addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT,
          CarbonCommonConstants.CARBON_DATE_DEFAULT_FORMAT)
    }
  }

  test("SELECT max(dateField) FROM timestampTyeNullData where dateField is not null") {
    checkAnswer(
      sql("SELECT max(dateField) FROM timestampTyeNullData where dateField is not null"),
      Seq(Row(Date.valueOf("2015-07-23"))
      )
    )
  }
  test("SELECT * FROM timestampTyeNullData where dateField is null") {
    checkAnswer(
      sql("SELECT dateField FROM timestampTyeNullData where dateField is null"),
      Seq(Row(null)
      ))
  }

  override def afterAll {
    sql("drop table timestampTyeNullData")
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT,
        CarbonCommonConstants.CARBON_DATE_DEFAULT_FORMAT)
    CarbonProperties.getInstance().addProperty("carbon.direct.dictionary", "false")
  }

} 
Example 125
Source File: TimestampNoDictionaryColumnCastTestCase.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.directdictionary


import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll

import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties


class TimestampNoDictionaryColumnCastTestCase extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT,
        CarbonCommonConstants.CARBON_TIMESTAMP_DEFAULT_FORMAT)
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT,
        CarbonCommonConstants.CARBON_DATE_DEFAULT_FORMAT)

      sql("drop table if exists timestamp_nodictionary")
    sql("drop table if exists datetype")
      sql(
        """
         CREATE TABLE IF NOT EXISTS timestamp_nodictionary
        (timestamptype timestamp) STORED AS carbondata"""
      )
      val csvFilePath = s"$resourcesPath/timestampdatafile.csv"
      sql(s"LOAD DATA LOCAL INPATH '$csvFilePath' into table timestamp_nodictionary")
//
    sql(
      """
         CREATE TABLE IF NOT EXISTS datetype
        (datetype1 date) STORED AS carbondata"""
    )
    val csvFilePath1 = s"$resourcesPath/datedatafile.csv"
    sql(s"LOAD DATA LOCAL INPATH '$csvFilePath1' into table datetype")
  }

  ignore("select count(*) from timestamp_nodictionary where timestamptype BETWEEN '2018-09-11' AND '2018-09-16'") {
    checkAnswer(
      sql("select count(*) from timestamp_nodictionary where timestamptype BETWEEN '2018-09-11' AND '2018-09-16'"),
      Seq(Row(6)
      )
    )
  }
//
  test("select count(*) from datetype where datetype1 BETWEEN '2018-09-11' AND '2018-09-16'") {
    checkAnswer(
      sql("select count(*) from datetype where datetype1 BETWEEN '2018-09-11' AND '2018-09-16'"),
      Seq(Row(6)
      )
    )
  }

  override def afterAll {
    sql("drop table timestamp_nodictionary")
    sql("drop table if exists datetype")
  }
} 
Example 126
Source File: TimestampDataTypeNullDataTest.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.directdictionary

import java.io.File
import java.sql.Timestamp

import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.scalatest.BeforeAndAfterAll
import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.keygenerator.directdictionary.timestamp.TimeStampGranularityConstants
import org.apache.carbondata.core.util.CarbonProperties
import org.apache.spark.sql.test.util.QueryTest


class TimestampDataTypeNullDataTest extends QueryTest with BeforeAndAfterAll {
  var hiveContext: HiveContext = _

  override def beforeAll {
    try {
      CarbonProperties.getInstance()
        .addProperty(TimeStampGranularityConstants.CARBON_CUTOFF_TIMESTAMP, "2000-12-13 02:10.00.0")
      CarbonProperties.getInstance()
        .addProperty(TimeStampGranularityConstants.CARBON_TIME_GRANULARITY,
          TimeStampGranularityConstants.TIME_GRAN_SEC.toString
        )
      sql(
        """CREATE TABLE IF NOT EXISTS timestampTyeNullData
                     (ID Int, dateField Timestamp, country String,
                     name String, phonetype String, serialname String, salary Int)
                    STORED AS carbondata"""
      )

      CarbonProperties.getInstance()
        .addProperty(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT, "yyyy/MM/dd")
      val csvFilePath = s"$resourcesPath/datasamplenull.csv"
      sql("LOAD DATA LOCAL INPATH '" + csvFilePath + "' INTO TABLE timestampTyeNullData").collect();

    } catch {
      case x: Throwable => CarbonProperties.getInstance()
        .addProperty(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT,
          CarbonCommonConstants.CARBON_TIMESTAMP_DEFAULT_FORMAT)
    }
  }

  test("SELECT max(dateField) FROM timestampTyeNullData where dateField is not null") {
    checkAnswer(
      sql("SELECT max(dateField) FROM timestampTyeNullData where dateField is not null"),
      Seq(Row(Timestamp.valueOf("2015-07-23 00:00:00.0"))
      )
    )
  }
  test("SELECT * FROM timestampTyeNullData where dateField is null") {
    checkAnswer(
      sql("SELECT dateField FROM timestampTyeNullData where dateField is null"),
      Seq(Row(null)
      ))
  }

  override def afterAll {
    sql("drop table timestampTyeNullData")
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT,
        CarbonCommonConstants.CARBON_TIMESTAMP_DEFAULT_FORMAT)
    CarbonProperties.getInstance().addProperty("carbon.direct.dictionary", "false")
  }

} 
Example 127
Source File: DateDataTypeDirectDictionaryWithOffHeapSortDisabledTest.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.directdictionary

import java.sql.Date

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll

import org.apache.carbondata.common.constants.LoggerAction
import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties


class DateDataTypeDirectDictionaryWithOffHeapSortDisabledTest
  extends QueryTest with BeforeAndAfterAll {
  private val originOffHeapSortStatus: String = CarbonProperties.getInstance()
    .getProperty(CarbonCommonConstants.ENABLE_OFFHEAP_SORT,
      CarbonCommonConstants.ENABLE_OFFHEAP_SORT_DEFAULT)

  override def beforeAll {
    try {
      CarbonProperties.getInstance().addProperty("carbon.direct.dictionary", "true")
      CarbonProperties.getInstance().addProperty(
        CarbonCommonConstants.CARBON_BAD_RECORDS_ACTION, LoggerAction.FORCE.name())
      CarbonProperties.getInstance().addProperty(CarbonCommonConstants.ENABLE_OFFHEAP_SORT, "false")

      sql("drop table if exists directDictionaryTable ")
      sql("CREATE TABLE if not exists directDictionaryTable (empno int,doj date, salary int) " +
        "STORED AS carbondata")

      CarbonProperties.getInstance()
        .addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT, "yyyy-MM-dd")
      val csvFilePath = s"$resourcesPath/datasamplefordate.csv"
      sql("LOAD DATA local inpath '" + csvFilePath + "' INTO TABLE directDictionaryTable OPTIONS" +
          "('DELIMITER'= ',', 'QUOTECHAR'= '\"')" )
    } catch {
      case x: Throwable =>
        x.printStackTrace()
        CarbonProperties.getInstance().addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT,
            CarbonCommonConstants.CARBON_DATE_DEFAULT_FORMAT)
    }
  }

  test("test direct dictionary for not null condition") {
    checkAnswer(sql("select doj from directDictionaryTable where doj is not null"),
      Seq(Row(Date.valueOf("2016-03-14")), Row(Date.valueOf("2016-04-14"))))
  }

  override def afterAll {
    sql("drop table directDictionaryTable")
    CarbonProperties.getInstance().addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT,
        CarbonCommonConstants.CARBON_DATE_DEFAULT_FORMAT)
    CarbonProperties.getInstance().addProperty("carbon.direct.dictionary", "false")
    CarbonProperties.getInstance().addProperty(CarbonCommonConstants.ENABLE_OFFHEAP_SORT,
      originOffHeapSortStatus)
  }
} 
Example 128
Source File: TimestampDataTypeDirectDictionaryWithNoDictTestCase.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.directdictionary

import java.sql.Timestamp

import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.scalatest.BeforeAndAfterAll
import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.keygenerator.directdictionary.timestamp.TimeStampGranularityConstants
import org.apache.carbondata.core.util.CarbonProperties
import org.apache.spark.sql.test.util.QueryTest


class TimestampDataTypeDirectDictionaryWithNoDictTestCase extends QueryTest with BeforeAndAfterAll {
  var hiveContext: HiveContext = _

  override def beforeAll {
    CarbonProperties.getInstance()
      .addProperty(TimeStampGranularityConstants.CARBON_CUTOFF_TIMESTAMP, "2000-12-13 02:10.00.0")
    CarbonProperties.getInstance()
      .addProperty(TimeStampGranularityConstants.CARBON_TIME_GRANULARITY,
        TimeStampGranularityConstants.TIME_GRAN_SEC.toString
      )
    CarbonProperties.getInstance().addProperty("carbon.direct.dictionary", "true")
    sql(
      """
         CREATE TABLE IF NOT EXISTS directDictionaryTable
        (empno String, doj Timestamp, salary Int)
         STORED AS carbondata"""
    )
    val csvFilePath = s"$resourcesPath/datasample.csv"
    sql("LOAD DATA local inpath '" + csvFilePath + "' INTO TABLE directDictionaryTable OPTIONS"
        + "('DELIMITER'= ',', 'QUOTECHAR'= '\"')")
  }

  test("select doj from directDictionaryTable") {
    checkAnswer(
      sql("select doj from directDictionaryTable"),
      Seq(Row(Timestamp.valueOf("2016-03-14 15:00:09.0")),
        Row(Timestamp.valueOf("2016-04-14 15:00:09.0")),
        Row(null)
      )
    )
  }


  test("select doj from directDictionaryTable with equals filter") {
    checkAnswer(
      sql("select doj from directDictionaryTable where doj='2016-03-14 15:00:09'"),
      Seq(Row(Timestamp.valueOf("2016-03-14 15:00:09")))
    )

  }

  test("select doj from directDictionaryTable with greater than filter") {
    checkAnswer(
      sql("select doj from directDictionaryTable where doj>'2016-03-14 15:00:09'"),
      Seq(Row(Timestamp.valueOf("2016-04-14 15:00:09")))
    )

  }


  override def afterAll {
    sql("drop table directDictionaryTable")
    CarbonProperties.getInstance().addProperty("carbon.direct.dictionary", "false")
  }
} 
Example 129
Source File: TimestampNoDictionaryColumnTestCase.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.directdictionary

import java.sql.Timestamp

import org.apache.spark.sql.Row
import org.scalatest.BeforeAndAfterAll
import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties
import org.apache.spark.sql.test.util.QueryTest


class TimestampNoDictionaryColumnTestCase extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT, "dd-MM-yyyy")

    sql("drop table if exists timestamp_nodictionary")
    sql(
      """
         CREATE TABLE IF NOT EXISTS timestamp_nodictionary
        (empno int, empname String, designation String, doj Timestamp, workgroupcategory int,
        workgroupcategoryname String,
         projectcode int, projectjoindate Timestamp, projectenddate Timestamp, attendance int,
         utilization int, salary Int) STORED AS carbondata"""
    )

    val csvFilePath = s"$resourcesPath/data_beyond68yrs.csv"
    sql("LOAD DATA local inpath '" + csvFilePath + "' INTO TABLE timestamp_nodictionary OPTIONS"
        + "('DELIMITER'= ',', 'QUOTECHAR'= '\"')")
  }

  test("select projectjoindate, projectenddate from timestamp_nodictionary") {
    checkAnswer(
      sql("select projectjoindate, projectenddate from timestamp_nodictionary"),
      Seq(Row(Timestamp.valueOf("2000-01-29 00:00:00.0"), Timestamp.valueOf("2016-06-29 00:00:00.0")),
        Row(Timestamp.valueOf("1800-02-17 00:00:00.0"), Timestamp.valueOf("1900-11-29 00:00:00.0")),
        Row(null, Timestamp.valueOf("2016-05-29 00:00:00.0")),
        Row(null, Timestamp.valueOf("2016-11-30 00:00:00.0")),
        Row(Timestamp.valueOf("3000-10-22 00:00:00.0"), Timestamp.valueOf("3002-11-15 00:00:00.0")),
        Row(Timestamp.valueOf("1802-06-29 00:00:00.0"), Timestamp.valueOf("1902-12-30 00:00:00.0")),
        Row(null, Timestamp.valueOf("2016-12-30 00:00:00.0")),
        Row(Timestamp.valueOf("2038-11-14 00:00:00.0"), Timestamp.valueOf("2041-12-29 00:00:00.0")),
        Row(null, null),
        Row(Timestamp.valueOf("2014-09-15 00:00:00.0"), Timestamp.valueOf("2016-05-29 00:00:00.0"))
      )
    )
  }


  test("select projectjoindate, projectenddate from timestamp_nodictionary where in filter") {
    checkAnswer(
      sql("select projectjoindate, projectenddate from timestamp_nodictionary where projectjoindate in" +
          "('1800-02-17 00:00:00','3000-10-22 00:00:00') or projectenddate in ('1900-11-29 00:00:00'," +
          "'3002-11-15 00:00:00','2041-12-29 00:00:00')"),
      Seq(Row(Timestamp.valueOf("1800-02-17 00:00:00.0"), Timestamp.valueOf("1900-11-29 00:00:00.0")),
        Row(Timestamp.valueOf("3000-10-22 00:00:00.0"), Timestamp.valueOf("3002-11-15 00:00:00.0")),
        Row(Timestamp.valueOf("2038-11-14 00:00:00.0"), Timestamp.valueOf("2041-12-29 00:00:00.0")))
    )

  }


  override def afterAll {
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT,
        CarbonCommonConstants.CARBON_TIMESTAMP_DEFAULT_FORMAT)
    sql("drop table timestamp_nodictionary")
  }
} 
Example 130
Source File: DateDataTypeDirectDictionaryWithNoDictTestCase.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.directdictionary

import java.io.File
import java.sql.Date

import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.scalatest.BeforeAndAfterAll
import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties
import org.apache.spark.sql.test.util.QueryTest


class DateDataTypeDirectDictionaryWithNoDictTestCase extends QueryTest with BeforeAndAfterAll {
  var hiveContext: HiveContext = _

  override def beforeAll {
    try {
      sql("drop table if exists directDictionaryTable")
      CarbonProperties.getInstance().addProperty("carbon.direct.dictionary", "true")
      sql(
        """
         CREATE TABLE IF NOT EXISTS directDictionaryTable
        (empno String, doj Date, salary Int)
         STORED AS carbondata """
      )

      CarbonProperties.getInstance()
        .addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT, "yyyy-MM-dd")
      val csvFilePath = s"$resourcesPath/datasample.csv"
      println(csvFilePath)
      sql("LOAD DATA local inpath '" + csvFilePath + "' INTO TABLE directDictionaryTable OPTIONS"
        + "('DELIMITER'= ',', 'QUOTECHAR'= '\"')");
    } catch {
      case x: Throwable =>
        x.printStackTrace()
        CarbonProperties.getInstance()
        .addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT,
          CarbonCommonConstants.CARBON_DATE_DEFAULT_FORMAT)
    }
  }

  test("select doj from directDictionaryTable") {
    sql("select doj from directDictionaryTable")
    checkAnswer(
      sql("select doj from directDictionaryTable"),
      Seq(Row(Date.valueOf("2016-03-14")),
        Row(Date.valueOf("2016-04-14")),
        Row(null)
      )
    )
  }


  test("select doj from directDictionaryTable with equals filter") {
    sql("select doj from directDictionaryTable where doj='2016-03-14 15:00:09'")
    checkAnswer(
      sql("select doj from directDictionaryTable where doj='2016-03-14'"),
      Seq(Row(Date.valueOf("2016-03-14")))
    )

  }

  test("select doj from directDictionaryTable with greater than filter") {
    sql("select doj from directDictionaryTable where doj>'2016-03-14 15:00:09'")
    checkAnswer(
      sql("select doj from directDictionaryTable where doj>'2016-03-14 15:00:09'"),
      Seq(Row(Date.valueOf("2016-04-14")))
    )

  }


  override def afterAll {
    sql("drop table directDictionaryTable")
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_DATE_FORMAT,
        CarbonCommonConstants.CARBON_DATE_DEFAULT_FORMAT)
    CarbonProperties.getInstance().addProperty("carbon.direct.dictionary", "false")
  }
} 
Example 131
Source File: OrderByLimitTestCase.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.joinquery

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll



class OrderByLimitTestCase extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
    sql(
      "CREATE TABLE carbon1 (empno int, empname String, designation String, doj Timestamp, " +
      "workgroupcategory int, workgroupcategoryname String, deptno int, deptname String, " +
      "projectcode int, projectjoindate Timestamp, projectenddate Timestamp,attendance int," +
      "utilization int,salary int) STORED AS carbondata")
    sql(
      s"""LOAD DATA local inpath '$resourcesPath/data.csv' INTO TABLE carbon1 OPTIONS
          |('DELIMITER'= ',', 'QUOTECHAR'= '\"')""".stripMargin);

    sql(
      "CREATE TABLE carbon2 (empno int, empname String, designation String, doj Timestamp, " +
      "workgroupcategory int, workgroupcategoryname String, deptno int, deptname String, " +
      "projectcode int, projectjoindate Timestamp, projectenddate Timestamp,attendance int," +
      "utilization int,salary int) STORED AS carbondata")
    sql(
      s"""LOAD DATA local inpath '$resourcesPath/data.csv' INTO TABLE carbon2 OPTIONS
          |('DELIMITER'= ',', 'QUOTECHAR'= '\"')""".stripMargin);

    sql(
      "CREATE TABLE carbon1_hive (empno int, empname String, designation String, doj Timestamp, " +
      "workgroupcategory int, workgroupcategoryname String, deptno int, deptname String, " +
      "projectcode int, projectjoindate Timestamp, projectenddate Timestamp,attendance int," +
      "utilization int,salary int) row format delimited fields terminated by ','")
    sql(
      s"LOAD DATA local inpath '$resourcesPath/datawithoutheader.csv' INTO TABLE carbon1_hive ")

    sql(
      "CREATE TABLE carbon2_hive (empno int, empname String, designation String, doj Timestamp, " +
      "workgroupcategory int, workgroupcategoryname String, deptno int, deptname String, " +
      "projectcode int, projectjoindate Timestamp, projectenddate Timestamp,attendance int," +
      "utilization int,salary int) row format delimited fields terminated by ','")
    sql(
      s"LOAD DATA local inpath '$resourcesPath/datawithoutheader.csv' INTO TABLE carbon2_hive ");


  }

  test("test join with orderby limit") {
    checkAnswer(
      sql(
        "select a.empno,a.empname,a.workgroupcategoryname from carbon1 a full outer join carbon2 " +
        "b on substr(a.workgroupcategoryname," +
        "1,3)" +
        "=substr(b.workgroupcategoryname,1,3) order by a.empname limit 5"),
      sql(
        "select a.empno,a.empname,a.workgroupcategoryname from carbon1_hive a full outer join " +
        "carbon2_hive b on " +
        "substr(a" +
        ".workgroupcategoryname,1,3)=substr(b.workgroupcategoryname,1,3) order by a.empname limit" +
        " 5")
    )
  }

  override def afterAll {
    sql("drop table carbon1")
    sql("drop table carbon2")
    sql("drop table carbon1_hive")
    sql("drop table carbon2_hive")
  }
} 
Example 132
Source File: IntegerDataTypeTestCase.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.joinquery

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll


class IntegerDataTypeTestCase extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
    sql("CREATE TABLE integertypetablejoin (empno int, workgroupcategory string, deptno int, projectcode int,attendance int) STORED AS carbondata")
    sql(s"""LOAD DATA local inpath '$resourcesPath/data.csv' INTO TABLE integertypetablejoin OPTIONS('DELIMITER'= ',', 'QUOTECHAR'= '\"')""")
  }

  test("select empno from integertypetablejoin") {
    checkAnswer(
      sql("select empno from integertypetablejoin"),
      Seq(Row(11), Row(12), Row(13), Row(14), Row(15), Row(16), Row(17), Row(18), Row(19), Row(20)))
  }

  override def afterAll {
    sql("drop table integertypetablejoin")
  }
} 
Example 133
Source File: NullMeasureValueTestCaseAggregate.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.measurenullvalue

import org.apache.spark.sql.Row
import org.scalatest.BeforeAndAfterAll
import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties
import org.apache.spark.sql.test.util.QueryTest

class NullMeasureValueTestCaseAggregate extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
    sql("drop table IF EXISTS t3")
    sql(
      "CREATE TABLE IF NOT EXISTS t3 (ID Int, date Timestamp, country String, name String, " +
        "phonetype String, serialname String, salary Int) STORED AS carbondata"
    )
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT, "yyyy/MM/dd")
    sql(s"LOAD DATA LOCAL INPATH '$resourcesPath/nullmeasurevalue.csv' into table t3");
  }

  test("select count(salary) from t3") {
    checkAnswer(
      sql("select count(salary) from t3"),
      Seq(Row(0)))
  }
  test("select count(ditinct salary) from t3") {
    checkAnswer(
      sql("select count(distinct salary) from t3"),
      Seq(Row(0)))
  }
  
  test("select sum(salary) from t3") {
    checkAnswer(
      sql("select sum(salary) from t3"),
      Seq(Row(null)))
  }
  test("select avg(salary) from t3") {
    checkAnswer(
      sql("select avg(salary) from t3"),
      Seq(Row(null)))
  }
  
   test("select max(salary) from t3") {
    checkAnswer(
      sql("select max(salary) from t3"),
      Seq(Row(null)))
   }
   test("select min(salary) from t3") {
    checkAnswer(
      sql("select min(salary) from t3"),
      Seq(Row(null)))
   }
   test("select sum(distinct salary) from t3") {
    checkAnswer(
      sql("select sum(distinct salary) from t3"),
      Seq(Row(null)))
   }
   
  override def afterAll {
    sql("drop table t3")
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT,
        CarbonCommonConstants.CARBON_TIMESTAMP_DEFAULT_FORMAT)
  }
} 
Example 134
Source File: IntegerDataTypeTestCase.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.detailquery

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll


class IntegerDataTypeTestCase extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
    sql("CREATE TABLE integertypetable (empno int, workgroupcategory string, deptno int, projectcode int,attendance int) STORED AS carbondata")
    sql(s"""LOAD DATA LOCAL INPATH '$resourcesPath/data.csv' INTO TABLE integertypetable OPTIONS('DELIMITER'= ',', 'QUOTECHAR'= '\"')""")
  }

  test("select empno from integertypetable") {
    checkAnswer(
      sql("select empno from integertypetable"),
      Seq(Row(11), Row(12), Row(13), Row(14), Row(15), Row(16), Row(17), Row(18), Row(19), Row(20)))
  }

  override def afterAll {
    sql("drop table integertypetable")
  }
} 
Example 135
Source File: TestSegmentReadingForMultiThreading.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.segmentreading

import java.util.concurrent.TimeUnit

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future}

import org.apache.spark.sql.{CarbonUtils, Row}
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll



class TestSegmentReadingForMultiThreading extends QueryTest with BeforeAndAfterAll {

  override def beforeAll: Unit = {
    sql("DROP TABLE IF EXISTS carbon_table_MulTI_THread")
    sql(
      "CREATE TABLE carbon_table_MulTI_THread (empno int, empname String, designation String, doj " +
      "Timestamp, workgroupcategory int, workgroupcategoryname String, deptno int, deptname " +
      "String, projectcode int, projectjoindate Timestamp, projectenddate Timestamp,attendance " +
      "int,utilization int,salary int) STORED AS carbondata")
    sql(
      s"LOAD DATA LOCAL INPATH '$resourcesPath/data.csv' INTO TABLE carbon_table_MulTI_THread " +
      "OPTIONS('DELIMITER'= ',', 'QUOTECHAR'= '\"')")
    sql(
      s"LOAD DATA LOCAL INPATH '$resourcesPath/data1.csv' INTO TABLE carbon_table_MulTI_THread " +
      "OPTIONS('DELIMITER'= ',', 'QUOTECHAR'= '\"')")
    sql(
      s"LOAD DATA LOCAL INPATH '$resourcesPath/data.csv' INTO TABLE carbon_table_MulTI_THread " +
      "OPTIONS('DELIMITER'= ',', 'QUOTECHAR'= '\"')")
    sql(
      s"LOAD DATA LOCAL INPATH '$resourcesPath/data1.csv' INTO TABLE carbon_table_MulTI_THread " +
      "OPTIONS('DELIMITER'= ',', 'QUOTECHAR'= '\"')")
  }

  test("test multithreading for segment reading") {


    CarbonUtils.threadSet("carbon.input.segments.default.carbon_table_MulTI_THread", "1,2,3")
    val df = sql("select count(empno) from carbon_table_MulTI_THread")
    checkAnswer(df, Seq(Row(30)))

    val four = Future {
      CarbonUtils.threadSet("carbon.input.segments.default.carbon_table_MulTI_THread", "1,3")
      val df = sql("select count(empno) from carbon_table_MulTI_THread")
      checkAnswer(df, Seq(Row(20)))
    }

    val three = Future {
      CarbonUtils.threadSet("carbon.input.segments.default.carbon_table_MulTI_THread", "0,1,2")
      val df = sql("select count(empno) from carbon_table_MulTI_THread")
      checkAnswer(df, Seq(Row(30)))
    }


    val one = Future {
      CarbonUtils.threadSet("carbon.input.segments.default.carbon_table_MulTI_THread", "0,2")
      val df = sql("select count(empno) from carbon_table_MulTI_THread")
      checkAnswer(df, Seq(Row(20)))
    }

    val two = Future {
      CarbonUtils.threadSet("carbon.input.segments.default.carbon_table_MulTI_THread", "1")
      val df = sql("select count(empno) from carbon_table_MulTI_THread")
      checkAnswer(df, Seq(Row(10)))
    }
    Await.result(Future.sequence(Seq(one, two, three, four)), Duration(300, TimeUnit.SECONDS))
  }

  override def afterAll: Unit = {
    sql("DROP TABLE IF EXISTS carbon_table_MulTI_THread")
    CarbonUtils.threadUnset("carbon.input.segments.default.carbon_table_MulTI_THread")
  }
} 
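Each future above repeats the same threadSet / query / threadUnset pattern by hand. A small helper sketch (not part of the original suite) that wraps the CarbonUtils calls shown in the test in try/finally, so the thread-local segment list is always cleared even when the assertion fails:

  // Sketch only: `table` and `segments` mirror the arguments used in the test above.
  def withSegments[T](table: String, segments: String)(body: => T): T = {
    val key = s"carbon.input.segments.default.$table"
    CarbonUtils.threadSet(key, segments)
    try {
      body
    } finally {
      CarbonUtils.threadUnset(key)
    }
  }

  // usage, equivalent to the "1,3" future above:
  // withSegments("carbon_table_MulTI_THread", "1,3") {
  //   checkAnswer(sql("select count(empno) from carbon_table_MulTI_THread"), Seq(Row(20)))
  // }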
Example 136
Source File: StoredAsCarbondataSuite.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.sql.commands

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterEach

import org.apache.carbondata.core.constants.CarbonCommonConstants

class StoredAsCarbondataSuite extends QueryTest with BeforeAndAfterEach {
  override def beforeEach(): Unit = {
    sql("DROP TABLE IF EXISTS carbon_table")
    sql("DROP TABLE IF EXISTS tableSize3")
  }

  override def afterEach(): Unit = {
    sql("DROP TABLE IF EXISTS carbon_table")
    sql("DROP TABLE IF EXISTS tableSize3")
  }

  test("CARBONDATA-2262: Support the syntax of 'STORED AS CARBONDATA', upper case") {
    sql("CREATE TABLE carbon_table(key INT, value STRING) STORED AS CARBONDATA")
    sql("INSERT INTO carbon_table VALUES (28,'Bob')")
    checkAnswer(sql("SELECT * FROM carbon_table"), Seq(Row(28, "Bob")))
  }

  test("CARBONDATA-2262: Support the syntax of 'STORED AS carbondata', low case") {
    sql("CREATE TABLE carbon_table(key INT, value STRING) STORED AS carbondata")
    sql("INSERT INTO carbon_table VALUES (28,'Bob')")
    checkAnswer(sql("SELECT * FROM carbon_table"), Seq(Row(28, "Bob")))
  }

  test("CARBONDATA-2262: Support the syntax of 'STORED AS carbondata, get data size and index size after minor compaction") {
    sql("CREATE TABLE tableSize3 (empno INT, workgroupcategory STRING, deptno INT, projectcode INT, attendance INT) STORED AS carbondata")
    sql(s"""LOAD DATA LOCAL INPATH '$resourcesPath/data.csv' INTO TABLE tableSize3 OPTIONS ('DELIMITER'= ',', 'QUOTECHAR'= '\"', 'FILEHEADER'='')""")
    sql(s"""LOAD DATA LOCAL INPATH '$resourcesPath/data.csv' INTO TABLE tableSize3 OPTIONS ('DELIMITER'= ',', 'QUOTECHAR'= '\"', 'FILEHEADER'='')""")
    sql(s"""LOAD DATA LOCAL INPATH '$resourcesPath/data.csv' INTO TABLE tableSize3 OPTIONS ('DELIMITER'= ',', 'QUOTECHAR'= '\"', 'FILEHEADER'='')""")
    sql(s"""LOAD DATA LOCAL INPATH '$resourcesPath/data.csv' INTO TABLE tableSize3 OPTIONS ('DELIMITER'= ',', 'QUOTECHAR'= '\"', 'FILEHEADER'='')""")
    sql(s"""LOAD DATA LOCAL INPATH '$resourcesPath/data.csv' INTO TABLE tableSize3 OPTIONS ('DELIMITER'= ',', 'QUOTECHAR'= '\"', 'FILEHEADER'='')""")
    sql(s"""LOAD DATA LOCAL INPATH '$resourcesPath/data.csv' INTO TABLE tableSize3 OPTIONS ('DELIMITER'= ',', 'QUOTECHAR'= '\"', 'FILEHEADER'='')""")
    sql("ALTER TABLE tableSize3 COMPACT 'minor'")
    checkExistence(sql("DESCRIBE FORMATTED tableSize3"), true, CarbonCommonConstants.TABLE_DATA_SIZE)
    checkExistence(sql("DESCRIBE FORMATTED tableSize3"), true, CarbonCommonConstants.TABLE_INDEX_SIZE)
    val res3 = sql("DESCRIBE FORMATTED tableSize3").collect()
      .filter(row => row.getString(0).contains(CarbonCommonConstants.TABLE_DATA_SIZE) ||
        row.getString(0).contains(CarbonCommonConstants.TABLE_INDEX_SIZE))
    assert(res3.length == 2)
    res3.foreach(row => assert(row.getString(1).trim.substring(0, 3).toDouble > 0))
  }

  test("CARBONDATA-2262: Don't Support the syntax of 'STORED AS 'carbondata''") {
    try {
      sql("CREATE TABLE carbon_table(key INT, value STRING) STORED AS 'carbondata'")
    } catch {
      case e: Exception =>
        assert(e.getMessage.contains("mismatched input"))
    }
  }

  test("CARBONDATA-2262: Don't Support the syntax of 'stored by carbondata'") {
    try {
      sql("CREATE TABLE carbon_table(key INT, value STRING) STORED BY carbondata")
    } catch {
      case e: Exception =>
        assert(e.getMessage.contains("mismatched input"))
    }
  }

  test("CARBONDATA-2262: Don't Support the syntax of 'STORED AS  ', null format") {
    try {
      sql("CREATE TABLE carbon_table(key INT, value STRING) STORED AS  ")
    } catch {
      case e: Exception =>
        assert(e.getMessage.contains("no viable alternative at input") ||
        e.getMessage.contains("mismatched input '<EOF>' expecting "))
    }
  }

  test("CARBONDATA-2262: Don't Support the syntax of 'STORED AS carbon'") {
    try {
      sql("CREATE TABLE carbon_table(key INT, value STRING) STORED AS carbon")
    } catch {
      case e: Exception =>
        assert(e.getMessage.contains("Operation not allowed: STORED AS with file format 'carbon'"))
    }
  }
} 
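The negative-syntax tests above pass silently if no exception is thrown at all. A minimal sketch of the same check written with ScalaTest's intercept, as it might appear inside the same suite (the exception type is deliberately kept broad, since the parser's concrete exception class is not pinned down here):

  test("sketch: STORED AS with an unsupported format should fail") {
    val e = intercept[Exception] {
      sql("CREATE TABLE carbon_table(key INT, value STRING) STORED AS carbon")
    }
    assert(e.getMessage.contains("STORED AS with file format 'carbon'"))
  }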
Example 137
Source File: SparkCarbonStoreTest.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.store

import org.apache.spark.sql.{CarbonEnv, Row}
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll

import org.apache.carbondata.core.metadata.datatype.DataTypes
import org.apache.carbondata.core.scan.expression.conditional.EqualToExpression
import org.apache.carbondata.core.scan.expression.{ColumnExpression, LiteralExpression}

class SparkCarbonStoreTest extends QueryTest with BeforeAndAfterAll {

  private var store: CarbonStore = _

  override def beforeAll {
    sql("DROP TABLE IF EXISTS t1")
    sql("CREATE TABLE t1 (" +
        "empno int, empname String, designation String, doj Timestamp, " +
        "workgroupcategory int, workgroupcategoryname String, deptno int, deptname String," +
        "projectcode int, projectjoindate Timestamp, projectenddate Timestamp," +
        "attendance int,utilization int,salary int)" +
        "STORED AS carbondata")
    sql(s"""LOAD DATA LOCAL INPATH '$resourcesPath/data.csv' INTO TABLE t1 OPTIONS('DELIMITER'= ',', 'QUOTECHAR'= '\"')""")

    store = new SparkCarbonStore(sqlContext.sparkSession)
  }

  test("test CarbonStore.get, compare projection result") {
    val table = CarbonEnv.getCarbonTable(None, "t1")(sqlContext.sparkSession)
    val rows = store.scan(table.getAbsoluteTableIdentifier, Seq("empno", "empname").toArray)
    val sparkResult: Array[Row] = sql("select empno, empname from t1").collect()
    sparkResult.zipWithIndex.foreach { case (r: Row, i: Int) =>
      val carbonRow = rows.next()
      assertResult(r.get(0))(carbonRow.getData()(0))
      assertResult(r.get(1))(carbonRow.getData()(1))
    }
    assert(!rows.hasNext)
  }

  test("test CarbonStore.get, compare projection and filter result") {
    val table = CarbonEnv.getCarbonTable(None, "t1")(sqlContext.sparkSession)
    val filter = new EqualToExpression(
      new ColumnExpression("empno", DataTypes.INT),
      new LiteralExpression(10, DataTypes.INT))
    val rows = store.scan(table.getAbsoluteTableIdentifier, Seq("empno", "empname").toArray, filter)
    val sparkResult: Array[Row] = sql("select empno, empname from t1 where empno = 10").collect()
    sparkResult.zipWithIndex.foreach { case (r: Row, i: Int) =>
      val carbonRow = rows.next()
      assertResult(r.get(0))(carbonRow.getData()(0))
      assertResult(r.get(1))(carbonRow.getData()(1))
    }
    assert(!rows.hasNext)
  }

  test("test CarbonStore.sql") {
    val rows = store.sql("select empno, empname from t1 where empno = 10")
    val sparkResult: Array[Row] = sql("select empno, empname from t1 where empno = 10").collect()
    sparkResult.zipWithIndex.foreach { case (r: Row, i: Int) =>
      val carbonRow = rows.next()
      assertResult(r.get(0))(carbonRow.getData()(0))
      assertResult(r.get(1))(carbonRow.getData()(1))
    }
    assert(!rows.hasNext)
  }

  override def afterAll {
    sql("DROP TABLE IF EXISTS t1")
  }
} 
Example 138
Source File: DropTableTest.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.secondaryindex

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll

import org.apache.carbondata.common.exceptions.sql.MalformedCarbonCommandException

class DropTableTest extends QueryTest with BeforeAndAfterAll {

  test("test to drop parent table with all indexes") {
    sql("drop database if exists cd cascade")
    sql("create database cd")
    sql("show tables in cd").show()
    sql("create table cd.t1 (a string, b string, c string) STORED AS carbondata")
    sql("create index i1 on table cd.t1(c) AS 'carbondata'")
    sql("create index i2 on table cd.t1(c,b) AS 'carbondata'")
    sql("show tables in cd").show()
    sql("drop table cd.t1")
    assert(sql("show tables in cd").collect()
      .forall(row => row.getString(1) != "i2" && row != Row("cd", "i1", "false") && row != Row("cd", "t1", "false")))
  }

  test("test to drop index tables") {
    sql("drop database if exists cd cascade")
    sql("create database cd")
    sql("create table cd.t1 (a string, b string, c string) STORED AS carbondata")
    sql("create index i1 on table cd.t1(c) AS 'carbondata'")
    sql("create index i2 on table cd.t1(c,b) AS 'carbondata'")
    sql("show tables in cd").show()
    sql("drop index i1 on cd.t1")
    sql("drop index i2 on cd.t1")
    assert(sql("show tables in cd").collect()
      .forall(row => !row.getString(1).equals("i1") && !row.getString(1).equals("i2") && row.getString(1).equals("t1")))
    assert(sql("show indexes on cd.t1").collect().isEmpty)
  }

  test("test drop index command") {
    sql("drop table if exists testDrop")
    sql("create table testDrop (a string, b string, c string) STORED AS carbondata")
    val exception = intercept[MalformedCarbonCommandException] {
      sql("drop index indTestDrop on testDrop")
    }
    assert(exception.getMessage.contains("Index with name indtestdrop does not exist"))
    sql("drop table if exists testDrop")
  }
} 
Example 139
Source File: TestIndexWithIndexModelOnFirstColumnAndSortColumns.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.secondaryindex

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll

import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties

class TestIndexWithIndexModelOnFirstColumnAndSortColumns extends QueryTest with BeforeAndAfterAll {

  var count1BeforeIndex : Array[Row] = null

  override def beforeAll {

    sql("drop table if exists seccust")
    sql("create table seccust (id string, c_custkey string, c_name string, c_address string, c_nationkey string, c_phone string,c_acctbal decimal, c_mktsegment string, c_comment string) " +
        "STORED AS carbondata TBLPROPERTIES ('table_blocksize'='128','SORT_COLUMNS'='c_custkey,c_name','NO_INVERTED_INDEX'='c_nationkey')")
    sql(s"""load data  inpath '${resourcesPath}/secindex/firstunique.csv' into table seccust options('DELIMITER'='|','QUOTECHAR'='"','FILEHEADER'='id,c_custkey,c_name,c_address,c_nationkey,c_phone,c_acctbal,c_mktsegment,c_comment')""")
    sql(s"""load data  inpath '${resourcesPath}/secindex/secondunique.csv' into table seccust options('DELIMITER'='|','QUOTECHAR'='\"','FILEHEADER'='id,c_custkey,c_name,c_address,c_nationkey,c_phone,c_acctbal,c_mktsegment,c_comment')""")
    count1BeforeIndex = sql("select * from seccust where id = '1' limit 1").collect()
    sql("create index sc_indx1 on table seccust(id) AS 'carbondata'")
  }

  test("Test secondry index on 1st column and with sort columns") {
    checkAnswer(sql("select count(*) from seccust where id = '1'"),Row(2))
  }

  override def afterAll {
    sql("drop table if exists orders")
  }
} 
Example 140
Source File: TestCarbonJoin.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.secondaryindex

import org.apache.spark.sql.test.util.QueryTest
import org.apache.spark.sql.{CarbonEnv, Row}
import org.scalatest.BeforeAndAfterAll

class TestCarbonJoin extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
  }

  test("test broadcast FilterPushDown with alias") {
    sql("DROP TABLE IF EXISTS table1")
    sql("DROP TABLE IF EXISTS ptable")
    sql("DROP TABLE IF EXISTS result")
    sql("create table if not exists table1 (ID string) STORED AS carbondata")
    sql("insert into table1 select 'animal'")
    sql("insert into table1 select 'person'")
    sql("create table ptable(pid string) stored as parquet")
    sql("insert into table ptable values('person')")

    val df2 = sql("select id as f91 from table1")
    df2.createOrReplaceTempView("tempTable_2")
    sql("select t1.f91 from tempTable_2 t1, ptable t2 where t1.f91 = t2.pid ").write.saveAsTable("result")
    checkAnswer(sql("select count(*) from result"), Seq(Row(1)))
    checkAnswer(sql("select * from result"), Seq(Row("person")))

    sql("DROP TABLE IF EXISTS result")
    sql("DROP TABLE IF EXISTS table1")
    sql("DROP TABLE IF EXISTS patble")
  }

  test("test broadcast FilterPushDown with alias with SI") {
    sql("drop index if exists cindex on ctable")
    sql("DROP TABLE IF EXISTS ctable")
    sql("DROP TABLE IF EXISTS ptable")
    sql("DROP TABLE IF EXISTS result")
    sql("create table if not exists ctable (type int, id1 string, id string) stored as " +
        "carbondata")
    sql("create index cindex on table ctable (id) AS 'carbondata'")
    sql("insert into ctable select 0, 'animal1', 'animal'")
    sql("insert into ctable select 1, 'person1', 'person'")
    sql("create table ptable(pid string) stored as parquet")
    sql("insert into table ptable values('person')")
    val carbonTable = CarbonEnv.getCarbonTable(Option("default"), "ctable")(sqlContext.sparkSession)
    val df2 = sql("select id as f91 from ctable")
    df2.createOrReplaceTempView("tempTable_2")
    sql("select t1.f91 from tempTable_2 t1, ptable t2 where t1.f91 = t2.pid ").write
      .saveAsTable("result")
    checkAnswer(sql("select count(*) from result"), Seq(Row(1)))
    checkAnswer(sql("select * from result"), Seq(Row("person")))
    sql("DROP TABLE IF EXISTS result")
    sql("drop index if exists cindex on ctable")
    sql("DROP TABLE IF EXISTS ctable")
    sql("DROP TABLE IF EXISTS patble")
  }
} 
Example 141
Source File: TestRegisterIndexCarbonTable.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.secondaryindex

import java.io.{File, IOException}

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.TestQueryExecutor
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll

import org.apache.carbondata.core.constants.CarbonCommonConstants


class TestRegisterIndexCarbonTable extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
    sql("drop database if exists carbon cascade")
  }

  def restoreData(dblocation: String, tableName: String) = {
    val destination = dblocation + CarbonCommonConstants.FILE_SEPARATOR + tableName
    val source = dblocation+ "_back" + CarbonCommonConstants.FILE_SEPARATOR + tableName
    try {
      FileUtils.copyDirectory(new File(source), new File(destination))
      FileUtils.deleteDirectory(new File(source))
    } catch {
      case e : Exception =>
        throw new IOException("carbon table data restore failed.")
    } finally {

    }
  }
  def backUpData(dblocation: String, tableName: String) = {
    val source = dblocation + CarbonCommonConstants.FILE_SEPARATOR + tableName
    val destination = dblocation+ "_back" + CarbonCommonConstants.FILE_SEPARATOR + tableName
    try {
      FileUtils.copyDirectory(new File(source), new File(destination))
    } catch {
      case e : Exception =>
        throw new IOException("carbon table data backup failed.")
    }
  }
  test("register tables test") {
    val location = TestQueryExecutor.warehouse +
                           CarbonCommonConstants.FILE_SEPARATOR + "dbName"
    sql("drop database if exists carbon cascade")
    sql(s"create database carbon location '${location}'")
    sql("use carbon")
    sql("""create table carbon.carbontable (c1 string,c2 int,c3 string,c5 string) STORED AS carbondata""")
    sql("insert into carbontable select 'a',1,'aa','aaa'")
    sql("create index index_on_c3 on table carbontable (c3, c5) AS 'carbondata'")
    backUpData(location, "carbontable")
    backUpData(location, "index_on_c3")
    sql("drop table carbontable")
    restoreData(location, "carbontable")
    restoreData(location, "index_on_c3")
    sql("refresh table carbontable")
    sql("refresh table index_on_c3")
    checkAnswer(sql("select count(*) from carbontable"), Row(1))
    checkAnswer(sql("select c1 from carbontable"), Seq(Row("a")))
    sql("REGISTER INDEX TABLE index_on_c3 ON carbontable")
    assert(sql("show indexes on carbontable").collect().nonEmpty)
  }
  override def afterAll {
    sql("drop database if exists carbon cascade")
    sql("use default")
  }
} 
Example 142
Source File: TestLikeQueryWithIndex.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.secondaryindex

import org.apache.spark.sql.{CarbonDatasourceHadoopRelation, DataFrame, Row}
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll


class TestLikeQueryWithIndex extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
    sql("drop table if exists TCarbon")

    sql("CREATE TABLE IF NOT EXISTS TCarbon(ID Int, country String, "+
          "name String, phonetype String, serialname String) "+
        "STORED AS carbondata"
    )
    val csvFilePath = s"$resourcesPath/secindex/secondaryIndexLikeTest.csv"

    sql(
      s"LOAD DATA LOCAL INPATH '$csvFilePath' INTO TABLE TCarbon " +
      "OPTIONS('DELIMITER'= ',')"
    )

    sql("create index insert_index on table TCarbon (name) AS 'carbondata'")
  }

  test("select secondary index like query Contains") {
    val df = sql("select * from TCarbon where name like '%aaa1%'")
    secondaryIndexTableCheck(df, _.equalsIgnoreCase("TCarbon"))

    checkAnswer(
      sql("select * from TCarbon where name like '%aaa1%'"),
      Seq(Row(1, "china", "aaa1", "phone197", "A234"),
        Row(9, "china", "aaa1", "phone756", "A455"))
    )
  }

    test("select secondary index like query ends with") {
      val df = sql("select * from TCarbon where name like '%aaa1'")
      secondaryIndexTableCheck(df,_.equalsIgnoreCase("TCarbon"))

      checkAnswer(
        sql("select * from TCarbon where name like '%aaa1'"),
        Seq(Row(1, "china", "aaa1", "phone197", "A234"),
          Row(9, "china", "aaa1", "phone756", "A455"))
      )
    }

      test("select secondary index like query starts with") {
        val df = sql("select * from TCarbon where name like 'aaa1%'")
        secondaryIndexTableCheck(df, Set("insert_index","TCarbon").contains(_))

        checkAnswer(
          sql("select * from TCarbon where name like 'aaa1%'"),
          Seq(Row(1, "china", "aaa1", "phone197", "A234"),
            Row(9, "china", "aaa1", "phone756", "A455"))
        )
      }

  def secondaryIndexTableCheck(dataFrame:DataFrame,
      tableNameMatchCondition :(String) => Boolean): Unit ={
    dataFrame.queryExecution.sparkPlan.collect {
      case bcf: CarbonDatasourceHadoopRelation =>
        if(!tableNameMatchCondition(bcf.carbonTable.getTableUniqueName)){
          assert(true)
        }
    }
  }

  override def afterAll {
    sql("DROP INDEX if exists insert_index ON TCarbon")
    sql("drop table if exists TCarbon")
  }
} 
Example 143
Source File: TestIndexModelWithUnsafeColumnPage.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.secondaryindex

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.util.QueryTest
import org.scalatest.BeforeAndAfterAll

import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.util.CarbonProperties

class TestIndexModelWithUnsafeColumnPage extends QueryTest with BeforeAndAfterAll {

  override def beforeAll {
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.ENABLE_UNSAFE_COLUMN_PAGE, "true")
    sql("drop table if exists testSecondryIndex")
    sql("create table testSecondryIndex( a string,b string,c string) STORED AS carbondata")
    sql("insert into testSecondryIndex select 'babu','a','6'")
    sql("create index testSecondryIndex_IndexTable on table testSecondryIndex(b) AS 'carbondata'")
  }

  test("Test secondry index data count") {
    checkAnswer(sql("select count(*) from testSecondryIndex_IndexTable")
    ,Seq(Row(1)))
  }

  override def afterAll {
    sql("drop table if exists testIndexTable")
  }

} 
Example 144
Source File: RowStreamParserImp.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.streaming.parser

import java.text.SimpleDateFormat
import java.util

import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.processing.loading.ComplexDelimitersEnum
import org.apache.carbondata.processing.loading.constants.DataLoadProcessorConstants


class RowStreamParserImp extends CarbonStreamParser {

  var configuration: Configuration = null
  var isVarcharTypeMapping: Array[Boolean] = null
  var structType: StructType = null
  var encoder: ExpressionEncoder[Row] = null

  var timeStampFormat: SimpleDateFormat = null
  var dateFormat: SimpleDateFormat = null
  var complexDelimiters: util.ArrayList[String] = new util.ArrayList[String]()
  var serializationNullFormat: String = null

  override def initialize(configuration: Configuration,
      structType: StructType, isVarcharTypeMapping: Array[Boolean]): Unit = {
    this.configuration = configuration
    this.structType = structType
    this.encoder = RowEncoder.apply(this.structType).resolveAndBind()
    this.isVarcharTypeMapping = isVarcharTypeMapping

    this.timeStampFormat = new SimpleDateFormat(
      this.configuration.get(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT))
    this.dateFormat = new SimpleDateFormat(
      this.configuration.get(CarbonCommonConstants.CARBON_DATE_FORMAT))
    this.complexDelimiters.add(this.configuration.get("carbon_complex_delimiter_level_1"))
    this.complexDelimiters.add(this.configuration.get("carbon_complex_delimiter_level_2"))
    this.complexDelimiters.add(this.configuration.get("carbon_complex_delimiter_level_3"))
    this.complexDelimiters.add(ComplexDelimitersEnum.COMPLEX_DELIMITERS_LEVEL_4.value())
    this.serializationNullFormat =
      this.configuration.get(DataLoadProcessorConstants.SERIALIZATION_NULL_FORMAT)
  }

  override def parserRow(value: InternalRow): Array[Object] = {
    this.encoder.fromRow(value).toSeq.zipWithIndex.map { case (x, i) =>
      FieldConverter.objectToString(
        x, serializationNullFormat, complexDelimiters,
        timeStampFormat, dateFormat,
        isVarcharType = i < this.isVarcharTypeMapping.length && this.isVarcharTypeMapping(i),
        binaryCodec = null)
    }.toArray
  }

  override def close(): Unit = {
  }

} 
Example 145
Source File: DataLoader.scala    From variantsdwh   with Apache License 2.0 5 votes vote down vote up
package pl.edu.pw.ii.zsibio.dwh.benchmark

import com.typesafe.config.ConfigFactory
import org.apache.kudu.spark.kudu.KuduContext
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.{SparkConf, SparkContext}
import org.rogach.scallop.ScallopConf
import org.apache.kudu.spark.kudu._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{DataType, StructField, StructType}


object DataLoader {
  class RunConf(args:Array[String]) extends ScallopConf(args){

    val csvFile =opt[String]("csvFile",required = true, descr = "A CSV file to load" )
    val tableName =opt[String]("tableName",required = true, descr = "A table to load" )
    val storageType = opt[String]("storageType",required = true, descr = "Storage type parquet|orc|kudu|carbon")
    val dbName =opt[String]("dbName",required = true, descr = "Database name" )


    verify()
  }
  def main(args: Array[String]): Unit = {
    val runConf = new RunConf(args)
    val scConf = new SparkConf()
        .setAppName("DataLoader")
    val sc = new SparkContext(scConf)
    val sqlContext = new HiveContext(sc)


    if(runConf.storageType().toLowerCase() == "orc" || runConf.storageType().toLowerCase() == "parquet") {
      val df = sqlContext.read
        .format("com.databricks.spark.csv")
        .option("delimiter", "|")
        .option("nullValue","\\N")
        .option("inferSchema", "true") // Automatically infer data types
        .load(runConf.csvFile())
        .repartition(10)
      df.registerTempTable("temp_csv")
      sqlContext.sql(
        s"""
        |INSERT OVERWRITE TABLE ${runConf.dbName()}.${runConf.tableName()}
        |SELECT * FROM temp_csv
        """.stripMargin)
      }
    if(runConf.storageType().toLowerCase() == "kudu"){
      val confFile = ConfigFactory.load()
      val kuduMaster = confFile.getString("kudu.master.server")
      val kuduContext = new KuduContext(kuduMaster)
      val dfTarget = sqlContext.read.options(Map("kudu.master" -> kuduMaster,"kudu.table" -> runConf.tableName())).kudu
      val df = sqlContext.read
        .format("com.databricks.spark.csv")
        .option("delimiter", "|")
        .option("nullValue","\\N")
        .schema(dfTarget.schema)
        .load(runConf.csvFile())
        .repartition(10)
      kuduContext.upsertRows(df,runConf.tableName())
    }

  }

  private def synSchemas(inSchema:StructType, outSchema:StructType) = {

    val size = inSchema.fields.length
    val structFields = (0 to size - 1).map{
      i => StructField(outSchema.fields(i).name,inSchema.fields(i).dataType,outSchema.fields(i).nullable)
    }
    new StructType(structFields.toArray)

  }

} 
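The private synSchemas helper above is never exercised in main. A minimal sketch, assuming it is called from inside DataLoader where runConf and sqlContext are in scope and that the CSV columns line up positionally with the target table, of how it could combine the types inferred from the CSV with the names and nullability of the target table:

    // Sketch only: target table and CSV come from the same RunConf options as above.
    val target = sqlContext.table(s"${runConf.dbName()}.${runConf.tableName()}")
    val csv = sqlContext.read
      .format("com.databricks.spark.csv")
      .option("delimiter", "|")
      .option("nullValue", "\\N")
      .option("inferSchema", "true")
      .load(runConf.csvFile())
    // names/nullability from the target table, data types from the inferred CSV schema
    val merged = synSchemas(csv.schema, target.schema)
    val aligned = sqlContext.createDataFrame(csv.rdd, merged)
    aligned.write.insertInto(s"${runConf.dbName()}.${runConf.tableName()}")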
Example 146
Source File: SamplesGenerator.scala    From variantsdwh   with Apache License 2.0 5 votes vote down vote up
package pl.edu.pw.ii.zsibio.dwh.benchmark.generation

import java.lang

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.{Row, SQLContext}
import pl.edu.pw.ii.zsibio.dwh.benchmark.generation.model.GeneratedSample
import pl.edu.pw.ii.zsibio.dwh.benchmark.utils.Probability._

import scala.util.Random

// NOTE: the class declaration and the enclosing generator-method signature were
// elided in this excerpt; the header below is a hypothetical reconstruction
// (the config type and method name are placeholders) added only so the
// remaining code parses. RegionConfig is assumed to expose percent,
// regionName and afColumn, per the usages below.
class SamplesGenerator(sc: SparkContext, sqlContext: HiveContext, config: GeneratorConfig) {

  def generate(n: Int): RDD[GeneratedSample] = {
    val population = sqlContext.sql(
      s"""
        |SELECT
        |population,
        |geo_id as id,
        |geo_country_name_en as countryName,
        |geo_region_name_en as region
        |FROM
        |${config.countryPopulation}
      """.stripMargin)

    val populationDist = population.map {
      case Row(population: Int, id: Int, countryName: String, region: String) =>
        (region, (population.toLong, id))
    }.groupByKey().collect().toMap

    def selectCountry(region: String): Int = {
      populationDist(region).selectWithProbability()
    }

    sc.parallelize(1 to n)
      .map((_, selectRegion()))
      .map(row => GeneratedSample(row._1, selectCountry(row._2.regionName), selectDisease(), row._2.afColumn))
  }

  def selectDisease() = {
    val r = Math.random()
    val ret: java.lang.Long = if (r <= 0.05)
      (Random.nextInt(60) * 100).toLong
    else
      new lang.Long(-1)
    ret
  }

  def selectRegion(): RegionConfig = {
    val percentages = Seq((config.africaConfig.percent, config.africaConfig)
      , (config.americasConfig.percent, config.americasConfig)
      , (config.europaConfig.percent, config.europaConfig)
      , (config.finnishConfig.percent, config.finnishConfig)
      , (config.southAsianConfig.percent, config.southAsianConfig)
      , (config.westAsianConfig.percent, config.westAsianConfig)
    )

    percentages.selectWithProbability()
  }

} 
Example 147
Source File: ManyToManyNormalJoin.scala    From spark_training   with Apache License 2.0 5 votes vote down vote up
package com.malaska.spark.training.manytomany

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable

object ManyToManyNormalJoin {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args:Array[String]): Unit = {
    val jsonPath = args(0)

    val sparkSession = SparkSession.builder
      .master("local")
      .appName("my-spark-app")
      .config("spark.some.config.option", "config-value")
      .config("spark.driver.host","127.0.0.1")
      .getOrCreate()

    val jsonDf = sparkSession.read.json(jsonPath)

    val nGramWordCount = jsonDf.rdd.flatMap(r => {
      val actions = r.getAs[mutable.WrappedArray[Row]]("actions")

      val resultList = new mutable.MutableList[((Long, Long), Int)]

      actions.foreach(a => {
        val aValue = a.getAs[Long]("action")
        actions.foreach(b => {
          val bValue = b.getAs[Long]("action")
          if (aValue < bValue) {
            resultList.+=(((aValue, bValue), 1))
          }
        })
      })
      resultList.toSeq
    }).reduceByKey(_ + _)

    nGramWordCount.collect().foreach(r => {
      println(" - " + r)
    })
  }
} 
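For intuition, here is what the flatMap above emits for a single record whose actions are [1, 2, 3] (hypothetical data, just to make the pair generation concrete); reduceByKey then sums the 1s across all records:

val actions = Seq(1L, 2L, 3L)
val pairs = for (a <- actions; b <- actions if a < b) yield ((a, b), 1)
// pairs == List(((1,2),1), ((1,3),1), ((2,3),1))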
Example 148
Source File: ManyToManyNestedJoin.scala    From spark_training   with Apache License 2.0 5 votes vote down vote up
package com.malaska.spark.training.manytomany

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable

object ManyToManyNestedJoin {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args:Array[String]): Unit = {
    val jsonPath = args(0)

    val sparkSession = SparkSession.builder
      .master("local")
      .appName("my-spark-app")
      .config("spark.some.config.option", "config-value")
      .config("spark.driver.host","127.0.0.1")
      .getOrCreate()

    val jsonDf = sparkSession.read.json(jsonPath)

    val nGramWordCount = jsonDf.rdd.flatMap(r => {
      val actions = r.getAs[mutable.WrappedArray[Row]]("actions")

      val resultList = new mutable.MutableList[(Long, NestedCount)]

      actions.foreach(a => {
        val aValue = a.getAs[Long]("action")
        val aNestedCount = new NestedCount
        actions.foreach(b => {
          val bValue = b.getAs[Long]("action")
          if (aValue < bValue) {
            aNestedCount.+=(bValue, 1)
          }
        })
        resultList.+=((aValue, aNestedCount))
      })
      resultList.toSeq
    }).reduceByKey((a, b) => a + b)

      //.reduceByKey(_ + _)

    nGramWordCount.collect().foreach(r => {
      println(" - " + r)
    })
  }
}


//1,2
//1,3
//1,4

//1 (2, 3, 4)

class NestedCount() extends Serializable{

  val map = new mutable.HashMap[Long, Long]()

  def += (key:Long, count:Long): Unit = {
    val currentValue = map.getOrElse(key, 0L)
    map.put(key, currentValue + count)
  }

  def + (other:NestedCount): NestedCount = {
    val result = new NestedCount

    other.map.foreach(r => {
      result.+=(r._1, r._2)
    })
    this.map.foreach(r => {
      result.+=(r._1, r._2)
    })
    result
  }

  override def toString(): String = {
    val stringBuilder = new StringBuilder
    map.foreach(r => {
      stringBuilder.append("(" + r._1 + "," + r._2 + ")")
    })
    stringBuilder.toString()
  }
} 
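A standalone illustration (not in the original file) of the NestedCount merge semantics relied on by the reduceByKey above: += accumulates a count for one key, and + merges two accumulators key by key:

val a = new NestedCount
a += (2L, 1L)
a += (3L, 1L)
val b = new NestedCount
b += (2L, 1L)
val merged = a + b
println(merged)   // e.g. (2,2)(3,1); HashMap ordering is not guaranteed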
Example 149
Source File: JsonNestedExample.scala    From spark_training   with Apache License 2.0 5 votes vote down vote up
package com.malaska.spark.training.nested

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}

import scala.collection.mutable

object JsonNestedExample {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args: Array[String]): Unit = {

    val isLocal = args(0).equalsIgnoreCase("l")
    val jsonPath = args(1)
    val outputTableName = args(2)

    val sparkSession = if (isLocal) {
      SparkSession.builder
        .master("local")
        .appName("my-spark-app")
        .config("spark.some.config.option", "config-value")
        .config("spark.driver.host","127.0.0.1")
        .config("spark.sql.parquet.compression.codec", "gzip")
        .enableHiveSupport()
        .getOrCreate()
    } else {
      SparkSession.builder
        .appName("my-spark-app")
        .config("spark.some.config.option", "config-value")
        .enableHiveSupport()
        .getOrCreate()
    }
    println("---")

    val jsonDf = sparkSession.read.json(jsonPath)

    val localJsonDf = jsonDf.collect()

    println("--Df")
    jsonDf.foreach(row => {
      println("row:" + row)
    })
    println("--local")
    localJsonDf.foreach(row => {
      println("row:" + row)
    })

    jsonDf.createOrReplaceTempView("json_table")

    println("--Tree Schema")
    jsonDf.schema.printTreeString()
    println("--")
    jsonDf.write.saveAsTable(outputTableName)

    sparkSession.sqlContext.sql("select * from " + outputTableName).take(10).foreach(println)

    println("--")
    
    sparkSession.stop()
  }

  def populatedFlattedHashMap(row:Row,
                              schema:StructType,
                              fields:Array[StructField],
                              flattedMap:mutable.HashMap[(String, DataType), mutable.MutableList[Any]],
                              parentFieldName:String): Unit = {
    fields.foreach(field => {

      println("field:" + field.dataType)
      if (field.dataType.isInstanceOf[ArrayType]) {
        val elementType = field.dataType.asInstanceOf[ArrayType].elementType
        if (elementType.isInstanceOf[StructType]) {
          val childSchema = elementType.asInstanceOf[StructType]

          val childRow = Row.fromSeq(row.getAs[mutable.WrappedArray[Any]](field.name).toSeq)

          populatedFlattedHashMap(childRow, childSchema, childSchema.fields, flattedMap, parentFieldName + field.name + ".")
        }
      } else {
        val fieldList = flattedMap.getOrElseUpdate((parentFieldName + field.name, field.dataType), new mutable.MutableList[Any])
        fieldList.+=:(row.getAs[Any](schema.fieldIndex(field.name)))
      }

    })
  }
} 
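populatedFlattedHashMap above is defined but never called from main. A minimal driver sketch, assuming it runs where jsonDf is in scope, that flattens the first row into the (columnPath, dataType) -> values map the method builds:

    // Sketch only: "" is the root prefix for top-level field names.
    val flattened = new mutable.HashMap[(String, DataType), mutable.MutableList[Any]]()
    val firstRow = jsonDf.head()
    populatedFlattedHashMap(firstRow, jsonDf.schema, jsonDf.schema.fields, flattened, "")
    flattened.foreach { case ((path, dataType), values) =>
      println(s"$path ($dataType): $values")
    }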
Example 150
Source File: NestedTableExample.scala    From spark_training   with Apache License 2.0 5 votes vote down vote up
package com.malaska.spark.training.nested

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object NestedTableExample {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder
      .master("local")
      .appName("my-spark-app")
      .config("spark.some.config.option", "config-value")
      .config("spark.driver.host","127.0.0.1")
      .enableHiveSupport()
      .getOrCreate()


    spark.sql("create table IF NOT EXISTS nested_empty " +
      "( A int, " +
      "  B string, " +
      "  nested ARRAY<STRUCT< " +
      "     nested_C: int," +
      "     nested_D: string" +
      "  >>" +
      ") ")

    val rowRDD = spark.sparkContext.
      parallelize(Array(
        Row(1, "foo", Seq(Row(1, "barA"),Row(2, "bar"))),
        Row(2, "foo", Seq(Row(1, "barB"),Row(2, "bar"))),
        Row(3, "foo", Seq(Row(1, "barC"),Row(2, "bar")))))

    val emptyDf = spark.sql("select * from nested_empty limit 0")

    val tableSchema = emptyDf.schema

    val populated1Df = spark.sqlContext.createDataFrame(rowRDD, tableSchema)

    println("----")
    populated1Df.collect().foreach(r => println(" emptySchemaExample:" + r))

    val nestedSchema = new StructType()
      .add("nested_C", IntegerType)
      .add("nested_D", StringType)

    val definedSchema = new StructType()
      .add("A", IntegerType)
      .add("B", StringType)
      .add("nested", ArrayType(nestedSchema))

    val populated2Df = spark.sqlContext.createDataFrame(rowRDD, definedSchema)
    println("----")
    populated1Df.collect().foreach(r => println(" BuiltExample:" + r))

    spark.stop()
  }
} 
Example 151
Source File: PopulateHiveTable.scala    From spark_training   with Apache License 2.0 5 votes vote down vote up
package com.malaska.spark.training.nested

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}


object PopulateHiveTable {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder
      .master("local")
      .appName("my-spark-app")
      .config("spark.some.config.option", "config-value")
      .config("spark.driver.host","127.0.0.1")
      .config("spark.sql.parquet.compression.codec", "gzip")
      .enableHiveSupport()
      .getOrCreate()


    spark.sql("create table IF NOT EXISTS nested_empty " +
      "( A int, " +
      "  B string, " +
      "  nested ARRAY<STRUCT< " +
      "     nested_C: int," +
      "     nested_D: string" +
      "  >>" +
      ") ")

    val rowRDD = spark.sparkContext.
      parallelize(Array(
        Row(1, "foo", Seq(Row(1, "barA"),Row(2, "bar"))),
        Row(2, "foo", Seq(Row(1, "barB"),Row(2, "bar"))),
        Row(3, "foo", Seq(Row(1, "barC"),Row(2, "bar")))))

    val emptyDf = spark.sql("select * from nested_empty limit 0")

    val tableSchema = emptyDf.schema

    val populated1Df = spark.sqlContext.createDataFrame(rowRDD, tableSchema)

    populated1Df.repartition(2).write.saveAsTable("nested_populated")

    println("----")
    populated1Df.collect().foreach(r => println(" emptySchemaExample:" + r))

    val nestedSchema = new StructType()
      .add("nested_C", IntegerType)
      .add("nested_D", StringType)

    val definedSchema = new StructType()
      .add("A", IntegerType)
      .add("B", StringType)
      .add("nested", ArrayType(nestedSchema))

    val populated2Df = spark.sqlContext.createDataFrame(rowRDD, definedSchema)

    println("----")
    populated1Df.collect().foreach(r => println(" BuiltExample:" + r))

    spark.stop()
  }
} 
Example 152
Source File: WrapperTrait.scala    From sparker   with GNU General Public License v3.0 5 votes vote down vote up
package SparkER.Wrappers

import SparkER.DataStructures.{KeyValue, MatchingEntities, Profile}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

import scala.collection.mutable.MutableList

// NOTE: the enclosing trait declaration was elided in this excerpt; it is
// restored here (named after the source file) so the braces balance.
trait WrapperTrait {
  def rowToAttributes(columnNames: Array[String], row: Row, explodeInnerFields: Boolean = false, innerSeparator: String = ","): MutableList[KeyValue] = {
    val attributes: MutableList[KeyValue] = new MutableList()
    for (i <- 0 to row.size - 1) {
      try {
        val value = row(i)
        val attributeKey = columnNames(i)

        if (value != null) {
          value match {
            case listOfAttributes: Iterable[Any] =>
              listOfAttributes map {
                attributeValue =>
                  attributes += KeyValue(attributeKey, attributeValue.toString)
              }
            case stringAttribute: String =>
              if (explodeInnerFields) {
                stringAttribute.split(innerSeparator) map {
                  attributeValue =>
                    attributes += KeyValue(attributeKey, attributeValue)
                }
              }
              else {
                attributes += KeyValue(attributeKey, stringAttribute)
              }
            case singleAttribute =>
              attributes += KeyValue(attributeKey, singleAttribute.toString)
          }
        }
      }
      catch {
        case e: Throwable => println(e)
      }
    }
    attributes
  }
} 
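A minimal usage sketch for rowToAttributes, assuming the trait (as excerpted here) has no other abstract members; it turns every row of an already-loaded DataFrame into the KeyValue lists the wrappers build Profile objects from:

object RowToAttributesSketch extends WrapperTrait {
  // Sketch only: the column names come straight from the DataFrame schema.
  def toKeyValues(df: org.apache.spark.sql.DataFrame): RDD[MutableList[KeyValue]] = {
    val columns = df.columns
    df.rdd.map(row => rowToAttributes(columns, row, explodeInnerFields = true, innerSeparator = ";"))
  }
}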
Example 153
Source File: WrapperTrait.scala    From sparker   with GNU General Public License v3.0 5 votes vote down vote up
package Wrappers

import DataStructures.{KeyValue, MatchingEntities, Profile}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

import scala.collection.mutable.MutableList

// NOTE: the enclosing trait declaration was elided in this excerpt; it is
// restored here (named after the source file) so the braces balance.
trait WrapperTrait {
  def rowToAttributes(columnNames: Array[String], row: Row, explodeInnerFields: Boolean = false, innerSeparator: String = ","): MutableList[KeyValue] = {
    val attributes: MutableList[KeyValue] = new MutableList()
    for (i <- 0 to row.size - 1) {
      try {
        val value = row(i)
        val attributeKey = columnNames(i)

        if (value != null) {
          value match {
            case listOfAttributes: Iterable[Any] =>
              listOfAttributes map {
                attributeValue =>
                  attributes += KeyValue(attributeKey, attributeValue.toString)
              }
            case stringAttribute: String =>
              if (explodeInnerFields) {
                stringAttribute.split(innerSeparator) map {
                  attributeValue =>
                    attributes += KeyValue(attributeKey, attributeValue)
                }
              }
              else {
                attributes += KeyValue(attributeKey, stringAttribute)
              }
            case singleAttribute =>
              attributes += KeyValue(attributeKey, singleAttribute.toString)
          }
        }
      }
      catch {
        case e: Throwable => println(e)
      }
    }
    attributes
  }
} 
Example 155
Source File: SavingStream.scala    From cuesheet   with Apache License 2.0 5 votes vote down vote up
package com.kakao.cuesheet.convert

import com.kakao.mango.concurrent.{NamedExecutors, RichExecutorService}
import com.kakao.mango.text.ThreadSafeDateFormat
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.streaming.Time
import org.apache.spark.streaming.dstream.DStream

import java.util.concurrent.{Future => JFuture}
import scala.reflect.runtime.universe.TypeTag

object SavingStream {
  val yyyyMMdd = ThreadSafeDateFormat("yyyy-MM-dd")
  val hh = ThreadSafeDateFormat("HH")
  val mm = ThreadSafeDateFormat("mm")
  val m0 = (ms: Long) => mm(ms).charAt(0) + "0"
}


abstract class SavingStream[T](stream: DStream[T])(implicit ctx: HiveContext, es: ExecutorSupplier) {
  // Enclosing class reconstructed from the concrete subclasses below; ExecutorSupplier is assumed
  // to live in the same package, and toDF is the abstract hook each subclass overrides.
  import SavingStream._

  def toDF(rdd: RDD[T]): DataFrame

  @transient var executor: RichExecutorService = _

  def ex: RichExecutorService = {
    if (executor == null) {
      this.synchronized {
        if (executor == null) {
          executor = new RichExecutorService(es.get())
        }
      }
    }
    executor
  }

  def saveAsPartitionedTable(table: String, path: String, format: String = "orc")(toPartition: Time => Seq[(String, String)]): Unit = {
    stream.foreachRDD { (rdd, time) =>
      ex.submit {
        toDF(rdd).appendToExternalTablePartition(table, path, format, toPartition(time): _*)
      }
    }
  }

  def saveAsDailyPartitionedTable(table: String, path: String, dateColumn: String = "date", format: String = "orc"): Unit = {
    saveAsPartitionedTable(table, path, format) { time =>
      val ms = time.milliseconds
      Seq(dateColumn -> yyyyMMdd(ms))
    }
  }

  def saveAsHourlyPartitionedTable(table: String, path: String, dateColumn: String = "date", hourColumn: String = "hour", format: String = "orc"): Unit = {
    saveAsPartitionedTable(table, path, format) { time =>
      val ms = time.milliseconds
      Seq(dateColumn -> yyyyMMdd(ms), hourColumn -> hh(ms))
    }
  }

  def saveAsTenMinutelyPartitionedTable(table: String, path: String, dateColumn: String = "date", hourColumn: String = "hour", minuteColumn: String = "minute", format: String = "orc"): Unit = {
    saveAsPartitionedTable(table, path, format) { time =>
      val ms = time.milliseconds
      Seq(dateColumn -> yyyyMMdd(ms), hourColumn -> hh(ms), minuteColumn -> m0(ms))
    }
  }

  def saveAsMinutelyPartitionedTable(table: String, path: String, dateColumn: String = "date", hourColumn: String = "hour", minuteColumn: String = "minute", format: String = "orc"): Unit = {
    saveAsPartitionedTable(table, path, format) { time =>
      val ms = time.milliseconds
      Seq(dateColumn -> yyyyMMdd(ms), hourColumn -> hh(ms), minuteColumn -> mm(ms))
    }
  }

}

class ProductStream[T <: Product : TypeTag](stream: DStream[T])(implicit ctx: HiveContext, es: ExecutorSupplier) extends SavingStream[T](stream) {
  override def toDF(rdd: RDD[T]) = ctx.createDataFrame(rdd)
}

class JsonStream(stream: DStream[String])(implicit ctx: HiveContext, es: ExecutorSupplier) extends SavingStream[String](stream) {
  override def toDF(rdd: RDD[String]) = ctx.read.json(rdd)
}

class MapStream[T](stream: DStream[Map[String, T]])(implicit ctx: HiveContext, es: ExecutorSupplier) extends SavingStream[Map[String, T]](stream) {
  import com.kakao.mango.json._

  override def toDF(rdd: RDD[Map[String, T]]) = ctx.read.json(rdd.map(toJson))
}

class RowStream(stream: DStream[Row])(implicit ctx: HiveContext, es: ExecutorSupplier, schema: StructType) extends SavingStream[Row](stream) {
  override def toDF(rdd: RDD[Row]): DataFrame = ctx.createDataFrame(rdd, schema)
} 
Example 156
Source File: MemsqlRDD.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import java.sql.{Connection, PreparedStatement, ResultSet}

import com.memsql.spark.SQLGen.VariableList
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}

case class MemsqlRDD(query: String,
                     variables: VariableList,
                     options: MemsqlOptions,
                     schema: StructType,
                     expectedOutput: Seq[Attribute],
                     @transient val sc: SparkContext)
    extends RDD[Row](sc, Nil) {

  override protected def getPartitions: Array[Partition] =
    MemsqlQueryHelpers.GetPartitions(options, query, variables)

  override def compute(rawPartition: Partition, context: TaskContext): Iterator[Row] = {
    var closed                     = false
    var rs: ResultSet              = null
    var stmt: PreparedStatement    = null
    var conn: Connection           = null
    var partition: MemsqlPartition = rawPartition.asInstanceOf[MemsqlPartition]

    def tryClose(name: String, what: AutoCloseable): Unit = {
      try {
        if (what != null) { what.close() }
      } catch {
        case e: Exception => logWarning(s"Exception closing $name", e)
      }
    }

    def close(): Unit = {
      if (closed) { return }
      tryClose("resultset", rs)
      tryClose("statement", stmt)
      tryClose("connection", conn)
      closed = true
    }

    context.addTaskCompletionListener { context =>
      close()
    }

    conn = JdbcUtils.createConnectionFactory(partition.connectionInfo)()
    stmt = conn.prepareStatement(partition.query)
    JdbcHelpers.fillStatement(stmt, partition.variables)
    rs = stmt.executeQuery()

    var rowsIter = JdbcUtils.resultSetToRows(rs, schema)

    if (expectedOutput.nonEmpty) {
      val schemaDatatypes   = schema.map(_.dataType)
      val expectedDatatypes = expectedOutput.map(_.dataType)

      if (schemaDatatypes != expectedDatatypes) {
        val columnEncoders = schemaDatatypes.zip(expectedDatatypes).zipWithIndex.map {
          case ((_: StringType, _: NullType), _)     => ((_: Row) => null)
          case ((_: ShortType, _: BooleanType), i)   => ((r: Row) => r.getShort(i) != 0)
          case ((_: IntegerType, _: BooleanType), i) => ((r: Row) => r.getInt(i) != 0)
          case ((_: LongType, _: BooleanType), i)    => ((r: Row) => r.getLong(i) != 0)

          case ((l, r), i) => {
            options.assert(l == r, s"MemsqlRDD: unable to encode ${l} into ${r}")
            ((r: Row) => r.get(i))
          }
        }

        rowsIter = rowsIter
          .map(row => Row.fromSeq(columnEncoders.map(_(row))))
      }
    }

    CompletionIterator[Row, Iterator[Row]](new InterruptibleIterator[Row](context, rowsIter), close)
  }

} 
Example 157
Source File: DeltaLoad.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up
package com.adidas.analytics.algo

import com.adidas.analytics.algo.DeltaLoad._
import com.adidas.analytics.algo.core.Algorithm
import com.adidas.analytics.algo.shared.DateComponentDerivation
import com.adidas.analytics.config.DeltaLoadConfiguration.PartitionedDeltaLoadConfiguration
import com.adidas.analytics.util.DataFrameUtils._
import com.adidas.analytics.util._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.slf4j.{Logger, LoggerFactory}


class DeltaLoad(val spark: SparkSession, val dfs: DFSWrapper, val configLocation: String)
  extends Algorithm with PartitionedDeltaLoadConfiguration with DateComponentDerivation {
  // Enclosing class reconstructed from the companion's apply() below and from the members this body
  // uses (businessKey, targetPartitions, withDateComponents come from the mixed-in configuration traits).

  private def getUpsertRecords(deltaRecords: Dataset[Row], resultColumns: Seq[String]): Dataset[Row] = {
    // Create partition window - Partitioning by delta records logical key (i.e. technical key of active records)
    val partitionWindow = Window
      .partitionBy(businessKey.map(col): _*)
      .orderBy(technicalKey.map(component => col(component).desc): _*)

    // Ranking & projection
    val rankedDeltaRecords = deltaRecords
      .withColumn(rankingColumnName, row_number().over(partitionWindow))
      .filter(upsertRecordsModesFilterFunction)

    rankedDeltaRecords
      .filter(rankedDeltaRecords(rankingColumnName) === 1)
      .selectExpr(resultColumns: _*)
  }

  protected def withDatePartitions(spark: SparkSession, dfs: DFSWrapper, dataFrames: Vector[DataFrame]): Vector[DataFrame] = {
    logger.info("Adding partitioning information if needed")
    try {
      dataFrames.map { df =>
        if (df.columns.toSeq.intersect(targetPartitions) != targetPartitions){
          df.transform(withDateComponents(partitionSourceColumn, partitionSourceColumnFormat, targetPartitions))
        }
        else df
      }
    } catch {
      case e: Throwable =>
        logger.error("Cannot add partitioning information for data frames.", e)
        //TODO: Handle failure case properly
        throw new RuntimeException("Unable to transform data frames.", e)
    }
  }
}


object DeltaLoad {

  private val logger: Logger = LoggerFactory.getLogger(getClass)

  def apply(spark: SparkSession, dfs: DFSWrapper, configLocation: String): DeltaLoad = {
    new DeltaLoad(spark, dfs, configLocation)
  }
} 
Example 158
Source File: PartitionHelpers.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up
package com.adidas.analytics.algo.core

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}


trait PartitionHelpers {

  protected def getDistinctPartitions(outputDataFrame: DataFrame, targetPartitions: Seq[String]): Dataset[Row] = {
    val targetPartitionsColumns: Seq[Column] = targetPartitions.map(partitionString => col(partitionString))

    outputDataFrame.select(targetPartitionsColumns: _*).distinct
  }

  protected def getParameterValue(row: Row, partitionString: String): String =
    createParameterValue(row.get(row.fieldIndex(partitionString)))

  protected def createParameterValue(partitionRawValue: Any): String =
    partitionRawValue match {
      case value: java.lang.Short => value.toString
      case value: java.lang.Integer => value.toString
      case value: scala.Predef.String => "'" + value + "'"
      case null => throw new Exception("Partition Value is null. No support for null partitions!")
      case value => throw new Exception("Unsupported partition DataType: " + value.getClass)
    }
} 
Example 159
Source File: DeltaLoadConfiguration.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up
package com.adidas.analytics.config

import com.adidas.analytics.algo.core.Algorithm.{ReadOperation, SafeWriteOperation, UpdateStatisticsOperation}
import com.adidas.analytics.config.shared.{ConfigurationContext, DateComponentDerivationConfiguration, MetadataUpdateStrategy}
import com.adidas.analytics.util.DataFormat.ParquetFormat
import com.adidas.analytics.util.{DataFormat, InputReader, LoadMode, OutputWriter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SparkSession}


trait DeltaLoadConfiguration extends ConfigurationContext
  with UpdateStatisticsOperation
  with MetadataUpdateStrategy {

  protected val activeRecordsTable: String = configReader.getAs[String]("active_records_table_lake")
  protected val deltaRecordsTable: Option[String] = configReader.getAsOption[String]("delta_records_table_lake")
  protected val deltaRecordsFilePath: Option[String] = configReader.getAsOption[String]("delta_records_file_path")

  protected val businessKey: Seq[String] = configReader.getAsSeq[String]("business_key")
  protected val technicalKey: Seq[String] = configReader.getAsSeq[String]("technical_key")

  protected val rankingColumnName: String = "DELTA_LOAD_RANK"
  protected val recordModeColumnName: String = "recordmode"
  protected val upsertRecordModes: Seq[String] = Seq("", "N")
  protected val upsertRecordsModesFilterFunction: Row => Boolean = { row: Row =>
    var recordmode = ""
    try {
      recordmode = row.getAs[String](recordModeColumnName)
    } catch {
      case _ => recordmode = row.getAs[String](recordModeColumnName.toUpperCase)
    }
    recordmode == null || recordmode == "" || recordmode == "N"
  }
}


object DeltaLoadConfiguration {

  trait PartitionedDeltaLoadConfiguration extends DeltaLoadConfiguration with DateComponentDerivationConfiguration
    with ReadOperation with SafeWriteOperation {

    protected def spark: SparkSession

    override protected val targetPartitions: Seq[String] = configReader.getAsSeq[String]("target_partitions")
    override protected val partitionSourceColumn: String = configReader.getAs[String]("partition_column")
    override protected val partitionSourceColumnFormat: String = configReader.getAs[String]("partition_column_format")

    private val targetSchema: StructType = spark.table(activeRecordsTable).schema

    override protected val readers: Vector[InputReader] = Vector(
      createDeltaInputReader(deltaRecordsFilePath, deltaRecordsTable),
      InputReader.newTableReader(table = activeRecordsTable)
    )

    override protected val writer: OutputWriter.AtomicWriter = OutputWriter.newTableLocationWriter(
      table = activeRecordsTable,
      format = ParquetFormat(Some(targetSchema)),
      targetPartitions = targetPartitions,
      metadataConfiguration = getMetaDataUpdateStrategy(activeRecordsTable, targetPartitions),
      loadMode = LoadMode.OverwritePartitionsWithAddedColumns
    )
  }

  private def createDeltaInputReader(deltaRecordsFilePath: Option[String], deltaRecordsTable: Option[String]): InputReader = {
    def createInputReaderByPath: InputReader = {
      deltaRecordsFilePath.fold {
        throw new RuntimeException("Unable to create a reader for the delta table: neither delta records path not delta table name is defined")
      } {
        location => InputReader.newFileSystemReader(s"$location*.parquet", DataFormat.ParquetFormat())
      }
    }

    deltaRecordsTable.fold(createInputReaderByPath)(tableName => InputReader.newTableReader(tableName))
  }
} 
Example 160
Source File: DataFrameUtils.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up
package com.adidas.analytics.util

import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, functions}
import org.slf4j.{Logger, LoggerFactory}


object DataFrameUtils {

  private val logger: Logger = LoggerFactory.getLogger(getClass)

  type FilterFunction = Row => Boolean

  type PartitionCriteria = Seq[(String, String)]

  def mapPartitionsToDirectories(partitionCriteria: PartitionCriteria): Seq[String] = {
    partitionCriteria.map {
      case (columnName, columnValue) => s"$columnName=$columnValue"
    }
  }

  def buildPartitionsCriteriaMatcherFunc(multiplePartitionsCriteria: Seq[PartitionCriteria], schema: StructType): FilterFunction = {
    val targetPartitions = multiplePartitionsCriteria.flatten.map(_._1).toSet
    val fieldNameToMatchFunctionMapping = schema.fields.filter {
      case StructField(name, _, _, _) => targetPartitions.contains(name)
    }.map {
      case StructField(name, _: ByteType, _, _)    => name -> ((r: Row, value: String) => r.getAs[Byte](name)    == value.toByte)
      case StructField(name, _: ShortType, _, _)   => name -> ((r: Row, value: String) => r.getAs[Short](name)   == value.toShort)
      case StructField(name, _: IntegerType, _, _) => name -> ((r: Row, value: String) => r.getAs[Int](name)     == value.toInt)
      case StructField(name, _: LongType, _, _)    => name -> ((r: Row, value: String) => r.getAs[Long](name)    == value.toLong)
      case StructField(name, _: FloatType, _, _)   => name -> ((r: Row, value: String) => r.getAs[Float](name)   == value.toFloat)
      case StructField(name, _: DoubleType, _, _)  => name -> ((r: Row, value: String) => r.getAs[Double](name)  == value.toDouble)
      case StructField(name, _: BooleanType, _, _) => name -> ((r: Row, value: String) => r.getAs[Boolean](name) == value.toBoolean)
      case StructField(name, _: StringType, _, _)  => name -> ((r: Row, value: String) => r.getAs[String](name)  == value)
    }.toMap

    def convertPartitionCriteriaToFilterFunctions(partitionCriteria: PartitionCriteria): Seq[FilterFunction] = partitionCriteria.map {
      case (name, value) => (row: Row) => fieldNameToMatchFunctionMapping(name)(row, value)
    }

    def joinSinglePartitionFilterFunctionsWithAnd(partitionFilterFunctions: Seq[FilterFunction]): FilterFunction =
      partitionFilterFunctions
        .reduceOption((predicate1, predicate2) => (row: Row) => predicate1(row) && predicate2(row))
        .getOrElse((_: Row) => false)

    multiplePartitionsCriteria
      .map(convertPartitionCriteriaToFilterFunctions)
      .map(joinSinglePartitionFilterFunctionsWithAnd)
      .reduceOption((predicate1, predicate2) => (row: Row) => predicate1(row) || predicate2(row))
      .getOrElse((_: Row) => false)
  }


  implicit class DataFrameHelper(df: DataFrame) {

    def collectPartitions(targetPartitions: Seq[String]): Seq[PartitionCriteria] = {
      logger.info(s"Collecting unique partitions for partitions columns (${targetPartitions.mkString(", ")})")
      val partitions = df.selectExpr(targetPartitions: _*).distinct().collect()

      partitions.map { row =>
        targetPartitions.map { columnName =>
          Option(row.getAs[Any](columnName)) match {
            case Some(columnValue) => columnName -> columnValue.toString
            case None => throw new RuntimeException(s"Partition column '$columnName' contains null value")
          }
        }
      }
    }

    def addMissingColumns(targetSchema: StructType): DataFrame = {
      val dataFieldsSet = df.schema.fieldNames.toSet
      val selectColumns = targetSchema.fields.map { field =>
        if (dataFieldsSet.contains(field.name)) {
          functions.col(field.name)
        } else {
          functions.lit(null).cast(field.dataType).as(field.name)
        }
      }
      df.select(selectColumns: _*)
    }

    def isEmpty: Boolean = df.head(1).isEmpty

    def nonEmpty: Boolean = df.head(1).nonEmpty
  }
} 
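The matcher built above resolves each partition column through the DataFrame schema and then combines the predicates with AND inside a criteria set and OR across sets. Here is a small standalone sketch of that combination logic over plain Rows; column positions are hard-coded for brevity, whereas the utility above looks them up by name via the schema, and all names are illustrative:

import org.apache.spark.sql.Row

object PartitionFilterSketch {
  type FilterFunction = Row => Boolean

  def main(args: Array[String]): Unit = {
    // One criteria set per partition to keep, as (columnName, stringValue) pairs
    val criteria: Seq[Seq[(String, String)]] = Seq(
      Seq("year" -> "2020", "country" -> "pt"),
      Seq("year" -> "2021", "country" -> "de")
    )

    // AND inside a criteria set, OR across sets
    val matches: FilterFunction = row =>
      criteria.exists(_.forall {
        case ("year", value)    => row.getInt(0) == value.toInt
        case ("country", value) => row.getString(1) == value
        case _                  => false
      })

    val rows = Seq(Row(2020, "pt"), Row(2021, "pt"), Row(2021, "de"))
    rows.filter(matches).foreach(println)   // prints [2020,pt] and [2021,de]
  }
}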
Example 161
Source File: TestUtils.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up
package com.adidas.utils

import org.apache.spark.sql.functions.{col, count, lit}
import org.apache.spark.sql.{DataFrame, Row}

object TestUtils {

  implicit class ExtendedDataFrame(df: DataFrame) {

    def hasDiff(anotherDf: DataFrame): Boolean = {
      def printDiff(incoming: Boolean)(row: Row): Unit = {
        if (incoming) print("+ ") else print("- ")
        println(row)
      }

      val groupedDf = df.groupBy(df.columns.map(col): _*).agg(count(lit(1))).collect().toSet
      val groupedAnotherDf = anotherDf.groupBy(anotherDf.columns.map(col): _*).agg(count(lit(1))).collect().toSet

      groupedDf.diff(groupedAnotherDf).foreach(printDiff(incoming = true))
      groupedAnotherDf.diff(groupedDf).foreach(printDiff(incoming = false))

      groupedDf.diff(groupedAnotherDf).nonEmpty || groupedAnotherDf.diff(groupedDf).nonEmpty
    }
  }
} 
Example 162
Source File: SparkRecoverPartitionsCustomTest.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up
package com.adidas.analytics.unit

import com.adidas.analytics.util.SparkRecoverPartitionsCustom
import com.adidas.utils.SparkSessionWrapper
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, PrivateMethodTester}

import scala.collection.JavaConverters._

class SparkRecoverPartitionsCustomTest extends FunSuite
  with SparkSessionWrapper
  with PrivateMethodTester
  with Matchers
  with BeforeAndAfterAll{

  test("test conversion of String Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue("theValue")

    result should be("'theValue'")
  }

  test("test conversion of Short Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Short.valueOf("2"))

    result should be("2")
  }

  test("test conversion of Integer Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Integer.valueOf("4"))

    result should be("4")
  }

  test("test conversion of null Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(null)
    }
  }

  test("test conversion of not supported Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(false)
    }
  }

  test("test HiveQL statements Generation") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(
      tableName="test",
      targetPartitions = Seq("country","district")
    )

    val rowsInput = Seq(
      Row(1, "portugal", "porto"),
      Row(2, "germany", "herzogenaurach"),
      Row(3, "portugal", "coimbra")
    )

    val inputSchema = StructType(
      List(
        StructField("number", IntegerType, nullable = true),
        StructField("country", StringType, nullable = true),
        StructField("district", StringType, nullable = true)
      )
    )

    val expectedStatements: Seq[String] = Seq(
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='germany',district='herzogenaurach')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='coimbra')"
    )

    val testDataset: Dataset[Row] = spark.createDataset(rowsInput)(RowEncoder(inputSchema))

    val generateAddPartitionStatements = PrivateMethod[Dataset[String]]('generateAddPartitionStatements)

    val producedStatements: Seq[String] = (customSparkRecoverPartitions invokePrivate generateAddPartitionStatements(testDataset))
      .collectAsList()
      .asScala

    expectedStatements.sorted.toSet should equal(producedStatements.sorted.toSet)
  }

  override def afterAll(): Unit = {
    spark.stop()
  }

} 
Example 163
Source File: RecoverPartitionsCustomTest.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up
package com.adidas.analytics.unit

import com.adidas.analytics.util.RecoverPartitionsCustom
import com.adidas.utils.SparkSessionWrapper
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, PrivateMethodTester}

import scala.collection.JavaConverters._

class RecoverPartitionsCustomTest extends FunSuite
  with SparkSessionWrapper
  with PrivateMethodTester
  with Matchers
  with BeforeAndAfterAll{

  test("test conversion of String Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue("theValue")

    result should be("'theValue'")
  }

  test("test conversion of Short Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Short.valueOf("2"))

    result should be("2")
  }

  test("test conversion of Integer Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Integer.valueOf("4"))

    result should be("4")
  }

  test("test conversion of null Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(null)
    }
  }

  test("test conversion of not supported Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(false)
    }
  }

  test("test HiveQL statements Generation") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(
      tableName="test",
      targetPartitions = Seq("country","district")
    )

    val rowsInput = Seq(
      Row(1, "portugal", "porto"),
      Row(2, "germany", "herzogenaurach"),
      Row(3, "portugal", "coimbra")
    )

    val inputSchema = StructType(
      List(
        StructField("number", IntegerType, nullable = true),
        StructField("country", StringType, nullable = true),
        StructField("district", StringType, nullable = true)
      )
    )

    val expectedStatements: Seq[String] = Seq(
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='germany',district='herzogenaurach')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='coimbra')"
    )

    val testDataset: Dataset[Row] = spark.createDataset(rowsInput)(RowEncoder(inputSchema))

    val generateAddPartitionStatements = PrivateMethod[Dataset[String]]('generateAddPartitionStatements)

    val producedStatements: Seq[String] = (customSparkRecoverPartitions invokePrivate generateAddPartitionStatements(testDataset))
      .collectAsList()
      .asScala

    expectedStatements.sorted.toSet should equal(producedStatements.sorted.toSet)
  }

  override def afterAll(): Unit = {
    spark.stop()
  }

} 
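The statements this test expects follow a simple rule: each target partition column becomes name=value, string values are quoted, and the pairs are joined with commas. A hedged standalone sketch of that generation step over a Row (function and object names are illustrative; the real implementation lives in RecoverPartitionsCustom):

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object AddPartitionStatementSketch {
  // Same quoting rule the tests above check: strings quoted, numbers left bare
  private def quote(value: Any): String = value match {
    case s: String => s"'$s'"
    case other     => other.toString
  }

  def statement(table: String, targetPartitions: Seq[String], schema: StructType, row: Row): String = {
    val spec = targetPartitions
      .map(name => s"$name=${quote(row.get(schema.fieldIndex(name)))}")
      .mkString(",")
    s"ALTER TABLE $table ADD IF NOT EXISTS PARTITION($spec)"
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(List(
      StructField("number", IntegerType, nullable = true),
      StructField("country", StringType, nullable = true),
      StructField("district", StringType, nullable = true)
    ))
    println(statement("test", Seq("country", "district"), schema, Row(1, "portugal", "porto")))
    // ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')
  }
}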
Example 164
Source File: RowSyntax.scala    From spark-hyperloglog   with MIT License 5 votes vote down vote up
package com.collective.analytics.schema

import org.apache.spark.sql
import org.apache.spark.sql.Row
import org.apache.spark.sql.hyperloglog.MergeHyperLogLog

object RowSyntax {

  sealed trait ColumnType

  trait IntColumn extends ColumnType
  trait LongColumn extends ColumnType
  trait StringColumn extends ColumnType
  trait BinaryColumn extends ColumnType
  trait StringArrayColumn extends ColumnType
  trait HLLColumn extends ColumnType

  sealed trait ColumnReader[C <: ColumnType] { self =>
    type Out

    def read(row: sql.Row)(idx: Int): Out

    def map[Out1](f: Out => Out1): ColumnReader[C] {type Out = Out1} =
      new ColumnReader[C] {
        type Out = Out1

        def read(row: Row)(idx: Int): Out = {
          f(self.read(row)(idx))
        }
      }
  }

  implicit class RowOps(val row: Row) extends AnyVal {
    def read[C <: ColumnType](idx: Int)(implicit reader: ColumnReader[C]): reader.Out = {
      reader.read(row)(idx)
    }
  }

  class IntReader[C <: ColumnType] extends ColumnReader[C] {
    type Out = Int
    def read(row: Row)(idx: Int): Out = row.getInt(idx)
  }

  class LongReader[C <: ColumnType] extends ColumnReader[C] {
    type Out = Long
    def read(row: Row)(idx: Int): Out = row.getLong(idx)
  }

  class StringReader[C <: ColumnType] extends ColumnReader[C] {
    type Out = String
    def read(row: Row)(idx: Int): Out = row(idx) match {
      case null => ""
      case str: String => str
      case arr: Array[_] => new String(arr.asInstanceOf[Array[Byte]])
    }
  }

  class StringArrayReader[C <: ColumnType] extends ColumnReader[C] {
    type Out = Array[String]
    def read(row: Row)(idx: Int): Out = row(idx) match {
      case null => Array.empty[String]
      case arr: Array[_] => arr.map(_.toString)
    }
  }

  class BinaryReader[C <: ColumnType] extends ColumnReader[C] {
    type Out = Array[Byte]

    def read(row: Row)(idx: Int): Out = {
      row.getAs[Array[Byte]](idx)
    }
  }

  // Implicit Column Readers

  implicit val intReader = new IntReader[IntColumn]
  implicit val longReader = new LongReader[LongColumn]
  implicit val stringReader = new StringReader[StringColumn]
  implicit val stringArrayReader = new StringArrayReader[StringArrayColumn]
  implicit val binaryReader = new BinaryReader[BinaryColumn]

  implicit val cardinalityReader = new BinaryReader[HLLColumn] map { bytes =>
    MergeHyperLogLog.readHLLWritable(bytes).get()
  }

} 
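RowSyntax above is a type-indexed reader: the phantom column types select an implicit ColumnReader, and RowOps.read dispatches to it. A compact standalone sketch of the same typeclass idea (trait, object, and method names here are illustrative, not the library's API):

import org.apache.spark.sql.Row

object TypedReadSketch {
  // One reader per result type; the implicit chosen at the call site decides how the cell is decoded
  type Reader[A] = (Row, Int) => A

  implicit val intReader: Reader[Int]       = (row, idx) => row.getInt(idx)
  implicit val longReader: Reader[Long]     = (row, idx) => row.getLong(idx)
  implicit val stringReader: Reader[String] = (row, idx) => row.getString(idx)

  implicit class RowReadOps(row: Row) {
    def readAs[A](idx: Int)(implicit reader: Reader[A]): A = reader(row, idx)
  }

  def main(args: Array[String]): Unit = {
    val row = Row("bmw", "forbes.com", 15L, 3)
    println(row.readAs[String](0))   // bmw
    println(row.readAs[Long](2))     // 15
    println(row.readAs[Int](3))      // 3
  }
}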
Example 165
Source File: ActivityLog.scala    From spark-hyperloglog   with MIT License 5 votes vote down vote up
package com.collective.analytics.schema

import com.adroll.cantor.HLLCounter
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

case class ActivityLog(
  adId: String,
  siteId: String,
  cookiesHLL: HLLCounter,
  impressions: Long,
  clicks: Long
)

object ActivityLog { self =>

  val ad_id               = "ad_id"
  val site_id             = "site_id"
  val cookies_hll         = "cookies_hll"
  val impressions         = "impressions"
  val clicks              = "clicks"

  object Schema extends SchemaDefinition {
    val ad_id        = structField(self.ad_id,       StringType)
    val site_id      = structField(self.site_id,     StringType)
    val cookies_hll  = structField(self.cookies_hll, BinaryType)
    val impressions  = structField(self.impressions, LongType)
    val clicks       = structField(self.clicks,      LongType)
  }

  val schema: StructType = StructType(Schema.fields)

  import RowSyntax._

  def parse(row: Row): ActivityLog = {
    ActivityLog(
      row.read[StringColumn](0),
      row.read[StringColumn](1),
      row.read[HLLColumn](2),
      row.read[LongColumn](3),
      row.read[LongColumn](4)
    )
  }

} 
Example 166
Source File: ImpressionLog.scala    From spark-hyperloglog   with MIT License 5 votes vote down vote up
package com.collective.analytics.schema

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

case class ImpressionLog(
  adId: String,
  siteId: String,
  cookieId: String,
  impressions: Long,
  clicks: Long,
  segments: Array[String]
)

object ImpressionLog { self =>

  val ad_id               = "ad_id"
  val site_id             = "site_id"
  val cookie_id           = "cookie_id"
  val impressions         = "impressions"
  val clicks              = "clicks"
  val segments            = "segments"

  object Schema extends SchemaDefinition {
    val ad_id        = structField(self.ad_id,       StringType)
    val site_id      = structField(self.site_id,     StringType)
    val cookie_id    = structField(self.cookie_id,   StringType)
    val impressions  = structField(self.impressions, LongType)
    val clicks       = structField(self.clicks,      LongType)
    val segments     = structField(self.segments,    ArrayType(StringType))
  }

  val schema: StructType = StructType(Schema.fields)

  import RowSyntax._

  def parse(row: Row): ImpressionLog = ImpressionLog(
    row.read[StringColumn](0),
    row.read[StringColumn](1),
    row.read[StringColumn](2),
    row.read[LongColumn](3),
    row.read[LongColumn](4),
    row.read[StringArrayColumn](5)
  )

} 
Example 167
Source File: SegmentLog.scala    From spark-hyperloglog   with MIT License 5 votes vote down vote up
package com.collective.analytics.schema

import com.adroll.cantor.HLLCounter
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

case class SegmentLog(
  segment: String,
  cookiesHLL: HLLCounter,
  impressions: Long,
  clicks: Long
)

object SegmentLog { self =>

  val segment             = "segment"
  val cookies_hll         = "cookies_hll"
  val impressions         = "impressions"
  val clicks              = "clicks"

  object Schema extends SchemaDefinition {
    val segment      = structField(self.segment,     StringType)
    val cookies_hll  = structField(self.cookies_hll, BinaryType)
    val impressions  = structField(self.impressions, LongType)
    val clicks       = structField(self.clicks,      LongType)
  }

  val schema: StructType = StructType(Schema.fields)

  import RowSyntax._

  def parse(row: Row): SegmentLog = {
    SegmentLog(
      row.read[StringColumn](0),
      row.read[HLLColumn](1),
      row.read[LongColumn](2),
      row.read[LongColumn](3)
    )
  }

} 
Example 168
Source File: AudienceAnalyticsSpec.scala    From spark-hyperloglog   with MIT License 5 votes vote down vote up
package com.collective.analytics

import com.collective.analytics.schema.ImpressionLog
import org.apache.spark.sql.Row
import org.scalatest.{FlatSpec, ShouldMatchers}


class SparkAudienceAnalyticsSpec extends AudienceAnalyticsSpec with EmbeddedSparkContext {
  def builder: Vector[Row] => AudienceAnalytics =
    log => new SparkAudienceAnalytics(
      new AggregateImpressionLog(sqlContext.createDataFrame(sc.parallelize(log), ImpressionLog.schema))
    )
}

class InMemoryAudienceAnalyticsSpec extends AudienceAnalyticsSpec {
  def builder: Vector[Row] => AudienceAnalytics =
    log => new InMemoryAudienceAnalytics(log.map(ImpressionLog.parse))
}

abstract class AudienceAnalyticsSpec extends FlatSpec with ShouldMatchers with DataGenerator {

  def builder: Vector[Row] => AudienceAnalytics

  private val impressions =
    repeat(100, impressionRow("bmw", "forbes.com", 10L, 1L, Array("income:50000", "education:high-school", "interest:technology"))) ++
    repeat(100, impressionRow("bmw", "forbes.com", 5L, 2L, Array("income:50000", "education:college", "interest:auto"))) ++
    repeat(100, impressionRow("bmw", "auto.com", 7L, 0L, Array("income:100000", "education:high-school", "interest:auto"))) ++
    repeat(100, impressionRow("audi", "cnn.com", 2L, 0L, Array("income:50000", "interest:audi", "education:high-school")))

  //private val impressionLog = impressions.map(ImpressionLog.parse)

  private val analytics = builder(impressions)

  "InMemoryAudienceAnalytics" should "compute audience estimate" in {
    val bmwEstimate = analytics.audienceEstimate(Vector("bmw"))
    assert(bmwEstimate.cookiesHLL.size() == 3 * 100)
    assert(bmwEstimate.impressions == 22 * 100)
    assert(bmwEstimate.clicks == 3 * 100)

    val forbesEstimate = analytics.audienceEstimate(sites = Vector("forbes.com"))
    assert(forbesEstimate.cookiesHLL.size() == 2 * 100)
    assert(forbesEstimate.impressions == 15 * 100)
    assert(forbesEstimate.clicks == 3 * 100)
  }

  it should "compute segment estimate" in {
    val fiftyK = analytics.segmentsEstimate(Vector("income:50000"))
    assert(fiftyK.cookiesHLL.size() == 3 * 100)
    assert(fiftyK.impressions == 17 * 100)
    assert(fiftyK.clicks == 3 * 100)

    val highSchool = analytics.segmentsEstimate(Vector("education:high-school"))
    assert(highSchool.cookiesHLL.size() == 3 * 100)
    assert(highSchool.impressions == 19 * 100)
    assert(highSchool.clicks == 1 * 100)
  }

  it should "compute audience intersection" in {
    val bmwAudience = analytics.audienceEstimate(Vector("bmw"))
    val intersection = analytics.segmentsIntersection(bmwAudience).toMap

    assert(intersection.size == 7)
    assert(intersection("interest:audi") == 0)
    intersection("income:50000") should (be >= 180L and be <= 2020L)
  }

} 
Example 169
Source File: HyperLogLogSpec.scala    From spark-hyperloglog   with MIT License 5 votes vote down vote up
package com.collective.analytics

import com.collective.analytics.schema.ImpressionLog
import com.collective.analytics.schema.RowSyntax._
import org.apache.spark.sql.hyperloglog.functions
import functions._
import org.apache.spark.sql.Row
import org.scalatest.FlatSpec

class HyperLogLogSpec  extends FlatSpec with EmbeddedSparkContext {

  private val impressions = Seq(
    Row("bmw", "forbes.com", "cookie#1", 10L, 1L, Array("income:50000", "education:high-school", "interest:technology")),
    Row("bmw", "forbes.com", "cookie#2", 5L, 2L, Array("income:150000", "education:college", "interest:auto")),
    Row("bmw", "auto.com", "cookie#3", 7L, 0L, Array("income:100000", "education:phd", "interest:music"))
  )

  private val impressionLog =
    sqlContext.createDataFrame(sc.parallelize(impressions), ImpressionLog.schema)

  "HyperLogLog" should "calculate correct column cardinalities" in {
    val cookiesHLL = impressionLog.select(hyperLogLog(ImpressionLog.cookie_id)).first().read[HLLColumn](0)
    assert(cookiesHLL.size() == 3)

    val sitesHLL = impressionLog.select(hyperLogLog(ImpressionLog.site_id)).first().read[HLLColumn](0)
    assert(sitesHLL.size() == 2)
  }

} 
Example 170
Source File: AggregateImpressionLogSpec.scala    From spark-hyperloglog   with MIT License 5 votes vote down vote up
package com.collective.analytics

import com.collective.analytics.schema.{SegmentLog, ActivityLog, ImpressionLog}
import org.apache.spark.sql.Row
import org.scalatest.FlatSpec

class AggregateImpressionLogSpec extends FlatSpec with EmbeddedSparkContext {

  private val impressions = Seq(
    Row("bmw", "forbes.com", "cookie#1", 10L, 1L, Array("income:50000", "education:high-school", "interest:technology")),
    Row("bmw", "forbes.com", "cookie#2", 5L, 2L, Array("income:50000", "education:college", "interest:auto")),
    Row("bmw", "auto.com", "cookie#3", 7L, 0L, Array("income:100000", "education:high-school", "interest:auto"))
  )

  private val impressionLog =
    sqlContext.createDataFrame(sc.parallelize(impressions), ImpressionLog.schema)

  private val aggregate = new AggregateImpressionLog(impressionLog)

  "AggregateImpressionLog" should "build activity log" in {
    val activityLog = aggregate.activityLog().collect().map(ActivityLog.parse)

    assert(activityLog.length == 2)

    val bmwAtForbes = activityLog.find(r => r.adId == "bmw" && r.siteId == "forbes.com").get
    assert(bmwAtForbes.cookiesHLL.size() == 2)
    assert(bmwAtForbes.impressions == 15)
    assert(bmwAtForbes.clicks == 3)

    val bmwAtAuto = activityLog.find(r => r.adId == "bmw" && r.siteId == "auto.com").get
    assert(bmwAtAuto.cookiesHLL.size() == 1)
    assert(bmwAtAuto.impressions == 7)
    assert(bmwAtAuto.clicks == 0)
  }

  it should "build segment log" in {
    val segmentLog = aggregate.segmentLog().collect().map(SegmentLog.parse)

    assert(segmentLog.length == 6)

    val income50k = segmentLog.find(_.segment == "income:50000").get
    assert(income50k.cookiesHLL.size() == 2)
    assert(income50k.impressions == 15)
    assert(income50k.clicks == 3)
  }

} 
Example 171
Source File: DatasourceRDD.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource.receiver

import org.apache.spark.partial.{BoundedDouble, CountEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.streaming.datasource.config.ParametersUtils
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetOperator}
import org.apache.spark.{Logging, Partition, TaskContext}

private[datasource]
class DatasourceRDD(
                     @transient sqlContext: SQLContext,
                     inputSentences: InputSentences,
                     datasourceParams: Map[String, String]
                   ) extends RDD[Row](sqlContext.sparkContext, Nil) with Logging with ParametersUtils {

  private var totalCalculated: Option[Long] = None

  private val InitTableName = "initTable"
  private val LimitedTableName = "limitedTable"
  private val TempInitQuery = s"select * from $InitTableName"

  val dataFrame = inputSentences.offsetConditions.fold(sqlContext.sql(inputSentences.query)) { case offset =>
    val parsedQuery = parseInitialQuery
    val conditionsSentence = offset.fromOffset.extractConditionSentence(parsedQuery)
    val orderSentence = offset.fromOffset.extractOrderSentence(parsedQuery, inverse = offset.limitRecords.isEmpty)
    val limitSentence = inputSentences.extractLimitSentence

    sqlContext.sql(parsedQuery + conditionsSentence + orderSentence + limitSentence)
  }

  private def parseInitialQuery: String = {
    if (inputSentences.query.toUpperCase.contains("WHERE") ||
      inputSentences.query.toUpperCase.contains("ORDER") ||
      inputSentences.query.toUpperCase.contains("LIMIT")
    ) {
      sqlContext.sql(inputSentences.query).registerTempTable(InitTableName)
      TempInitQuery
    } else inputSentences.query
  }

  def progressInputSentences: InputSentences = {
    if (!dataFrame.rdd.isEmpty()) {
      inputSentences.offsetConditions.fold(inputSentences) { case offset =>

        val offsetValue = if (offset.limitRecords.isEmpty)
          dataFrame.rdd.first().get(dataFrame.schema.fieldIndex(offset.fromOffset.name))
        else {
          dataFrame.registerTempTable(LimitedTableName)
          val limitedQuery = s"select * from $LimitedTableName order by ${offset.fromOffset.name} " +
            s"${OffsetOperator.toInverseOrderOperator(offset.fromOffset.operator)} limit 1"

          sqlContext.sql(limitedQuery).rdd.first().get(dataFrame.schema.fieldIndex(offset.fromOffset.name))
        }

        inputSentences.copy(offsetConditions = Option(offset.copy(fromOffset = offset.fromOffset.copy(
          value = Option(offsetValue),
          operator = OffsetOperator.toProgressOperator(offset.fromOffset.operator)))))
      }
    } else inputSentences
  }

  
  override def isEmpty(): Boolean = {
    totalCalculated.fold {
      withScope {
        partitions.length == 0 || take(1).length == 0
      }
    } { total => total == 0L }
  }

  override def getPartitions: Array[Partition] = dataFrame.rdd.partitions

  override def compute(thePart: Partition, context: TaskContext): Iterator[Row] = dataFrame.rdd.compute(thePart, context)

  override def getPreferredLocations(thePart: Partition): Seq[String] = dataFrame.rdd.preferredLocations(thePart)
} 
Example 172
Source File: TemporalDataSuite.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.datasource.config.ConfigParameters._
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.BeforeAndAfter

private[datasource] trait TemporalDataSuite extends DatasourceSuite
  with BeforeAndAfter {

  val conf = new SparkConf()
    .setAppName("datasource-receiver-example")
    .setIfMissing("spark.master", "local[*]")
  var sc: SparkContext = null
  var ssc: StreamingContext = null
  val tableName = "tableName"
  val datasourceParams = Map(
    StopGracefully -> "true",
    StopSparkContext -> "false",
    StorageLevelKey -> "MEMORY_ONLY",
    RememberDuration -> "15s"
  )
  val schema = new StructType(Array(
    StructField("id", StringType, nullable = true),
    StructField("idInt", IntegerType, nullable = true)
  ))
  val totalRegisters = 10000
  val registers = for (a <- 1 to totalRegisters) yield Row(a.toString, a)

  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
    if (sc != null) {
      sc.stop()
      sc = null
    }
  }
} 
Example 173
Source File: SparkEsDataFrameFunctions.scala    From Spark2Elasticsearch   with Apache License 2.0 5 votes vote down vote up
package com.github.jparkie.spark.elasticsearch.sql

import com.github.jparkie.spark.elasticsearch.SparkEsBulkWriter
import com.github.jparkie.spark.elasticsearch.conf.{ SparkEsMapperConf, SparkEsTransportClientConf, SparkEsWriteConf }
import com.github.jparkie.spark.elasticsearch.transport.SparkEsTransportClientManager
import org.apache.spark.sql.{ DataFrame, Row }


  def bulkLoadToEs(
    esIndex:                    String,
    esType:                     String,
    sparkEsTransportClientConf: SparkEsTransportClientConf = SparkEsTransportClientConf.fromSparkConf(sparkContext.getConf),
    sparkEsMapperConf:          SparkEsMapperConf          = SparkEsMapperConf.fromSparkConf(sparkContext.getConf),
    sparkEsWriteConf:           SparkEsWriteConf           = SparkEsWriteConf.fromSparkConf(sparkContext.getConf)
  )(implicit sparkEsTransportClientManager: SparkEsTransportClientManager = sparkEsTransportClientManager): Unit = {
    val sparkEsWriter = new SparkEsBulkWriter[Row](
      esIndex = esIndex,
      esType = esType,
      esClient = () => sparkEsTransportClientManager.getTransportClient(sparkEsTransportClientConf),
      sparkEsSerializer = new SparkEsDataFrameSerializer(dataFrame.schema),
      sparkEsMapper = new SparkEsDataFrameMapper(sparkEsMapperConf),
      sparkEsWriteConf = sparkEsWriteConf
    )

    sparkContext.runJob(dataFrame.rdd, sparkEsWriter.write _)
  }
} 
Example 174
Source File: SparkEsBulkWriterSpec.scala    From Spark2Elasticsearch   with Apache License 2.0 5 votes vote down vote up
package com.github.jparkie.spark.elasticsearch

import com.github.jparkie.spark.elasticsearch.conf.{ SparkEsMapperConf, SparkEsWriteConf }
import com.github.jparkie.spark.elasticsearch.sql.{ SparkEsDataFrameMapper, SparkEsDataFrameSerializer }
import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
import org.apache.spark.sql.{ Row, SQLContext }
import org.scalatest.{ MustMatchers, WordSpec }

class SparkEsBulkWriterSpec extends WordSpec with MustMatchers with SharedSparkContext {
  val esServer = new ElasticSearchServer()

  override def beforeAll(): Unit = {
    super.beforeAll()

    esServer.start()
  }

  override def afterAll(): Unit = {
    esServer.stop()

    super.afterAll()
  }

  "SparkEsBulkWriter" must {
    "execute write() successfully" in {
      esServer.createAndWaitForIndex("test_index")

      val sqlContext = new SQLContext(sc)

      val inputSparkEsWriteConf = SparkEsWriteConf(
        bulkActions = 10,
        bulkSizeInMB = 1,
        concurrentRequests = 0,
        flushTimeoutInSeconds = 1
      )
      val inputMapperConf = SparkEsMapperConf(
        esMappingId = Some("id"),
        esMappingParent = None,
        esMappingVersion = None,
        esMappingVersionType = None,
        esMappingRouting = None,
        esMappingTTLInMillis = None,
        esMappingTimestamp = None
      )
      val inputSchema = StructType(
        Array(
          StructField("id", StringType, true),
          StructField("parent", StringType, true),
          StructField("version", LongType, true),
          StructField("routing", StringType, true),
          StructField("ttl", LongType, true),
          StructField("timestamp", StringType, true),
          StructField("value", LongType, true)
        )
      )
      val inputData = sc.parallelize {
        Array(
          Row("TEST_ID_1", "TEST_PARENT_1", 1L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 1L),
          Row("TEST_ID_1", "TEST_PARENT_2", 2L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 2L),
          Row("TEST_ID_1", "TEST_PARENT_3", 3L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 3L),
          Row("TEST_ID_1", "TEST_PARENT_4", 4L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 4L),
          Row("TEST_ID_1", "TEST_PARENT_5", 5L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 5L),
          Row("TEST_ID_5", "TEST_PARENT_6", 6L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 6L),
          Row("TEST_ID_6", "TEST_PARENT_7", 7L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 7L),
          Row("TEST_ID_7", "TEST_PARENT_8", 8L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 8L),
          Row("TEST_ID_8", "TEST_PARENT_9", 9L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 9L),
          Row("TEST_ID_9", "TEST_PARENT_10", 10L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 10L),
          Row("TEST_ID_10", "TEST_PARENT_11", 11L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 11L)
        )
      }
      val inputDataFrame = sqlContext.createDataFrame(inputData, inputSchema)
      val inputDataIterator = inputDataFrame.rdd.toLocalIterator
      val inputSparkEsBulkWriter = new SparkEsBulkWriter[Row](
        esIndex = "test_index",
        esType = "test_type",
        esClient = () => esServer.client,
        sparkEsSerializer = new SparkEsDataFrameSerializer(inputSchema),
        sparkEsMapper = new SparkEsDataFrameMapper(inputMapperConf),
        sparkEsWriteConf = inputSparkEsWriteConf
      )

      inputSparkEsBulkWriter.write(null, inputDataIterator)

      val outputGetResponse = esServer.client.prepareGet("test_index", "test_type", "TEST_ID_1").get()

      outputGetResponse.isExists mustEqual true
      outputGetResponse.getSource.get("parent").asInstanceOf[String] mustEqual "TEST_PARENT_5"
      outputGetResponse.getSource.get("version").asInstanceOf[Integer] mustEqual 5
      outputGetResponse.getSource.get("routing").asInstanceOf[String] mustEqual "TEST_ROUTING_1"
      outputGetResponse.getSource.get("ttl").asInstanceOf[Integer] mustEqual 86400000
      outputGetResponse.getSource.get("timestamp").asInstanceOf[String] mustEqual "TEST_TIMESTAMP_1"
      outputGetResponse.getSource.get("value").asInstanceOf[Integer] mustEqual 5
    }
  }
} 
Example 175
Source File: UDFTest.scala    From SparkGIS   with Apache License 2.0 5 votes vote down vote up
package org.betterers.spark.gis

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.{SQLContext, Row}
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.betterers.spark.gis.udf.Functions


class UDFTest extends FunSuite with BeforeAndAfter {
  import Geometry.WGS84

  val point = Geometry.point((2.0, 2.0))
  val multiPoint = Geometry.multiPoint((1.0, 1.0), (2.0, 2.0), (3.0, 3.0))
  var line = Geometry.line((11.0, 11.0), (12.0, 12.0))
  var multiLine = Geometry.multiLine(
    Seq((11.0, 1.0), (23.0, 23.0)),
    Seq((31.0, 3.0), (42.0, 42.0)))
  var polygon = Geometry.polygon((1.0, 1.0), (2.0, 2.0), (3.0, 1.0))
  var multiPolygon = Geometry.multiPolygon(
    Seq((1.0, 1.0), (2.0, 2.0), (3.0, 1.0)),
    Seq((1.1, 1.1), (2.0, 1.9), (2.5, 1.1))
  )
  val collection = Geometry.collection(point, multiPoint, line)
  val all: Seq[Geometry] = Seq(point, multiPoint, line, multiLine, polygon, multiPolygon, collection)

  var sc: SparkContext = _
  var sql: SQLContext = _

  before {
    sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("SparkGIS"))
    sql = new SQLContext(sc)
  }

  after {
    sc.stop()
  }

  test("ST_Boundary") {
    // all.foreach(g => println(Functions.ST_Boundary(g).toString))

    assertResult(true) {
      Functions.ST_Boundary(point).isEmpty
    }
    assertResult(true) {
      Functions.ST_Boundary(multiPoint).isEmpty
    }
    assertResult("Some(MULTIPOINT ((11 11), (12 12)))") {
      Functions.ST_Boundary(line).toString
    }
    assertResult(None) {
      Functions.ST_Boundary(multiLine)
    }
    assertResult("Some(LINEARRING (1 1, 2 2, 3 1, 1 1))") {
      Functions.ST_Boundary(polygon).toString
    }
    assertResult(None) {
      Functions.ST_Boundary(multiPolygon)
    }
    assertResult(None) {
      Functions.ST_Boundary(collection)
    }
  }

  test("ST_CoordDim") {
    all.foreach(g => {
      assertResult(3) {
        Functions.ST_CoordDim(g)
      }
    })
  }

  test("UDF in SQL") {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("geo", GeometryType.Instance)
    ))
    val jsons = Map(
      (1, "{\"type\":\"Point\",\"coordinates\":[1,1]}}"),
      (2, "{\"type\":\"LineString\",\"coordinates\":[[12,13],[15,20]]}}")
    )
    val rdd = sc.parallelize(Seq(
      "{\"id\":1,\"geo\":" + jsons(1) + "}",
      "{\"id\":2,\"geo\":" + jsons(2) + "}"
    ))
    rdd.name = "TEST"
    val df = sql.read.schema(schema).json(rdd)
    df.registerTempTable("TEST")
    Functions.register(sql)
    assertResult(Array(3,3)) {
      sql.sql("SELECT ST_CoordDim(geo) FROM TEST").collect().map(_.get(0))
    }
  }
} 
Example 176
Source File: NullValuesTest.scala    From spark-dynamodb   with Apache License 2.0 5 votes vote down vote up
package com.audienceproject.spark.dynamodb

import com.amazonaws.services.dynamodbv2.model.{AttributeDefinition, CreateTableRequest, KeySchemaElement, ProvisionedThroughput}
import com.audienceproject.spark.dynamodb.implicits._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class NullValuesTest extends AbstractInMemoryTest {

    test("Insert nested StructType with null values") {
        dynamoDB.createTable(new CreateTableRequest()
            .withTableName("NullTest")
            .withAttributeDefinitions(new AttributeDefinition("name", "S"))
            .withKeySchema(new KeySchemaElement("name", "HASH"))
            .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L)))

        val schema = StructType(
            Seq(
                StructField("name", StringType, nullable = false),
                StructField("info", StructType(
                    Seq(
                        StructField("age", IntegerType, nullable = true),
                        StructField("address", StringType, nullable = true)
                    )
                ), nullable = true)
            )
        )

        val rows = spark.sparkContext.parallelize(Seq(
            Row("one", Row(30, "Somewhere")),
            Row("two", null),
            Row("three", Row(null, null))
        ))

        val newItemsDs = spark.createDataFrame(rows, schema)

        newItemsDs.write.dynamodb("NullTest")

        val validationDs = spark.read.dynamodb("NullTest")

        validationDs.show(false)
    }

} 
Example 177
Source File: MessageDelimiter.scala    From spark-cep   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.sources

import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Cast, EmptyRow, Literal}
import org.apache.spark.sql.types.StructType

class MessageDelimiter extends MessageToRowConverter with Logging {
  val delimiter = " "

  def toRow(msg: String, schema: StructType): InternalRow = {
    val splitted = msg.split(delimiter).map(Literal(_))
    val casted = splitted.indices.map(i => Cast(splitted(i), schema(i).dataType).eval(EmptyRow))
    InternalRow.fromSeq(casted)
  }

  def toMessage(row: Row): String = row.mkString(delimiter)
}

trait MessageToRowConverter extends Serializable {
  def toRow(message: String, schema: StructType): InternalRow

  def toMessage(row: Row): String
} 
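The converter is symmetric: toRow casts each space-delimited token against the target schema into an InternalRow, while toMessage re-joins a Row with the same delimiter. Below is a minimal usage sketch (illustrative schema and values; it assumes the Spark 1.x classpath this project builds against, since MessageDelimiter mixes in org.apache.spark.Logging):

import org.apache.spark.sql.Row
import org.apache.spark.sql.streaming.sources.MessageDelimiter
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object MessageDelimiterSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))
    val converter = new MessageDelimiter
    // "42 alice" is split on the delimiter and each token is cast to its schema field's type
    val internal = converter.toRow("42 alice", schema)
    // a Row is turned back into a delimited message via Row.mkString
    val message = converter.toMessage(Row(42, "alice"))
    println(s"$internal / $message")
  }
}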
Example 178
Source File: JdbcUtil.scala    From bahir   with Apache License 2.0 5 votes vote down vote up
package org.apache.bahir.sql.streaming.jdbc

import java.sql.{Connection, PreparedStatement}
import java.util.Locale

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


object JdbcUtil {

  def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
    dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).getOrElse(
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
  }

  // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
  // `PreparedStatement`. The last argument `Int` means the index for the value to be set
  // in the SQL statement and also used for the value in `Row`.
  type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))

    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val strValue = row.get(pos) match {
          case str: UTF8String => str.toString
          case str: String => str
        }
        stmt.setString(pos + 1, strValue)

    case BinaryType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case ArrayType(et, _) =>
      // remove type length parameters from end of type name
      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
        .toLowerCase(Locale.ROOT).split("\\(")(0)
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val array = conn.createArrayOf(
          typeName,
          row.getSeq[AnyRef](pos).toArray)
        stmt.setArray(pos + 1, array)

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
} 
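makeSetter produces one positional binder per column, so a writer can zip the binders with a Row to fill an INSERT statement. Below is a minimal sketch (hypothetical JDBC URL and table name) that assumes the JdbcUtil object above is on the classpath:

import java.sql.DriverManager

import org.apache.bahir.sql.streaming.jdbc.JdbcUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.StructType

object JdbcUtilSketch {
  def insertRow(url: String, table: String, schema: StructType, row: Row): Unit = {
    val conn = DriverManager.getConnection(url)
    try {
      val dialect = JdbcDialects.get(url)
      val placeholders = schema.fields.map(_ => "?").mkString(", ")
      val stmt = conn.prepareStatement(s"INSERT INTO $table VALUES ($placeholders)")
      try {
        // one JDBCValueSetter per column, resolved from the column's Catalyst data type
        val setters = schema.fields.map(f => JdbcUtil.makeSetter(conn, dialect, f.dataType))
        setters.zipWithIndex.foreach { case (set, i) => set(stmt, row, i) }
        stmt.executeUpdate()
      } finally {
        stmt.close()
      }
    } finally {
      conn.close()
    }
  }
}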
Example 179
Source File: PredictNewsClassDemo.scala    From CkoocNLP   with Apache License 2.0 5 votes vote down vote up
package applications.mining

import algorithms.evaluation.MultiClassEvaluation
import config.paramconf.ClassParams
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.{Row, SparkSession}


object PredictNewsClassDemo extends Serializable {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark = SparkSession
      .builder
      .master("local[2]")
      .appName("predict news multi class demo")
      .getOrCreate()

    val args = Array("ckooc-ml/data/classnews/predict", "lr")
    val filePath = args(0)
    val modelType = args(1)

    var modelPath = ""
    val params = new ClassParams

    modelType match {
      case "lr" => modelPath = params.LRModelPath
      case "dt" => modelPath = params.DTModelPath
      case _ =>
        println("模型类型错误!")
        System.exit(1)
    }

    import spark.implicits._
    val data = spark.sparkContext.textFile(filePath).flatMap { line =>
      val tokens: Array[String] = line.split("\u00ef")
      if (tokens.length > 3) Some((tokens(0), tokens(1), tokens(2), tokens(3))) else None
    }.toDF("label", "title", "time", "content")
    data.persist()

    // Load the model and transform the data
    val model = PipelineModel.load(modelPath)
    val predictions = model.transform(data)

    //=== Model evaluation
    val resultRDD = predictions.select("prediction", "indexedLabel").rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) }
    val (precision, recall, f1) = MultiClassEvaluation.multiClassEvaluate(resultRDD)
    println("\n\n========= 评估结果 ==========")
    println(s"\n加权准确率:$precision")
    println(s"加权召回率:$recall")
    println(s"F1值:$f1")

    //    predictions.select("label", "predictedLabel", "content").show(100, truncate = false)
    data.unpersist()

    spark.stop()
  }
} 
Example 180
Source File: StarsAnalysisDemo.scala    From CkoocNLP   with Apache License 2.0 5 votes vote down vote up
package applications.analysis

import java.io.{BufferedWriter, FileOutputStream, OutputStreamWriter}

import functions.segment.Segmenter
import org.apache.log4j.{Level, Logger}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}


object StarsAnalysisDemo {
  def main(args: Array[String]) {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark = SparkSession
      .builder
      .master("local[2]")
      .appName("Stars Analysis Demo")
      .getOrCreate()

    val filePath = "E:/data/chinaNews/entertainment.txt"


    // Load the data, keep the year and content fields, and filter the content field
    import spark.implicits._
    val data = spark.sparkContext.textFile(filePath).flatMap { line =>
      val tokens: Array[String] = line.split("\u00ef")
      if (tokens.length > 3) {
        var year: String = tokens(2).split("-")(0)
        if (tokens(2).contains("年")) year = tokens(2).split("年")(0)

        var content = tokens(3)
        if (content.length > 22 && content.substring(0, 20).contains("日电")) {
          content = content.substring(content.indexOf("日电") + 2, content.length).trim
        }

        if (content.startsWith("(")) content = content.substring(content.indexOf(")") + 1, content.length)
        if (content.length > 20 && content.substring(content.length - 20, content.length).contains("记者")) {
          content = content.substring(0, content.lastIndexOf("记者")).trim
        }

        Some(year, content)
      } else None
    }.toDF("year", "content")

    // Segment the text, drop single-character terms, and keep the part-of-speech tag for each term
    val segmenter = new Segmenter()
      .isAddNature(true)
      .isDelEn(true)
      .isDelNum(true)
      .setMinTermLen(2)
      .setMinTermNum(5)
      .setSegType("StandardSegment")
      .setInputCol("content")
      .setOutputCol("segmented")
    val segDF: DataFrame = segmenter.transform(data)
    segDF.cache()

    val segRDD: RDD[(Int, Seq[String])] = segDF.select("year", "segmented").rdd.map { case Row(year: String, terms: Seq[String]) =>
      (Integer.parseInt(year), terms)
    }

    val result: Array[String] = segRDD.map(line => line._1.toString + "\u00ef" + line._2.mkString(",")).collect()
    val writer: BufferedWriter = new BufferedWriter(new OutputStreamWriter(new FileOutputStream("E:/entertainment_seg.txt")))
    result.foreach(line => writer.write(line + "\n"))
    writer.close()

    // Count the celebrities that appear most often in the 2016 news
    val stars2016 = segRDD.filter(_._1 == 2016)
      .flatMap { case (year: Int, termStr: Seq[String]) =>
        val person = termStr
          .map(term => (term.split("/")(0), term.split("/")(1)))
          .filter(_._2.equalsIgnoreCase("nr"))
          .map(term => (term._1, 1L))

        person
      }
      .reduceByKey(_ + _)
      .sortBy(_._2, ascending = false)

    segDF.unpersist()

    stars2016.take(100).foreach(println)

    spark.stop()
  }
} 
Example 181
Source File: NetezzaRDD.scala    From spark-netezza   with Apache License 2.0 5 votes vote down vote up
package com.ibm.spark.netezza

import java.sql.Connection
import java.util.Properties

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.{Partition, SparkContext, TaskContext}


  override def compute(thePart: Partition, context: TaskContext): Iterator[Row] =
    new Iterator[Row] {
      var closed = false
      var finished = false
      var gotNext = false
      var nextValue: Row = null

      context.addTaskCompletionListener { context => close() }
      val part = thePart.asInstanceOf[NetezzaPartition]
      val conn = getConnection()
      val reader = new NetezzaDataReader(conn, table, columns, filters, part, schema)
      reader.startExternalTableDataUnload()

      def getNext(): Row = {
        if (reader.hasNext) {
          reader.next()
        } else {
          finished = true
          null.asInstanceOf[Row]
        }
      }

      def close() {
        if (closed) return
        try {
          if (null != reader) {
            reader.close()
          }
        } catch {
          case e: Exception => logWarning("Exception closing Netezza record reader", e)
        }
        try {
          if (null != conn) {
            conn.close()
          }
          logInfo("closed connection")
        } catch {
          case e: Exception => logWarning("Exception closing connection", e)
        }
      }

      override def hasNext: Boolean = {
        if (!finished) {
          if (!gotNext) {
            nextValue = getNext()
            if (finished) {
              close()
            }
            gotNext = true
          }
        }
        !finished
      }

      override def next(): Row = {
        if (!hasNext) {
          throw new NoSuchElementException("End of stream")
        }
        gotNext = false
        nextValue
      }
    }
} 
Example 182
Source File: QueryTest.scala    From spark-netezza   with Apache License 2.0 5 votes vote down vote up
package com.ibm.spark.netezza.integration

import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.{DataFrame, Row}
import org.scalatest.FunSuite

  def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
    val isSorted = df.queryExecution.logical.collect { case s: logical.Sort => s }.nonEmpty

    val sparkAnswer = try df.collect().toSeq catch {
      case e: Exception =>
        val errorMessage =
          s"""
             |Exception thrown while executing query:
             |${df.queryExecution}
             |== Exception ==
             |$e
             |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
          """.stripMargin
        return Some(errorMessage)
    }

    sameRows(expectedAnswer, sparkAnswer, isSorted).map { results =>
      s"""
         |Results do not match for query:
         |${df.queryExecution}
         |== Results ==
         |$results
       """.stripMargin
    }
  }

  def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
    // Converts data to types that we can do equality comparison using Scala collections.
    // For BigDecimal type, the Scala type has a better definition of equality test (similar to
    // Java's java.math.BigDecimal.compareTo).
    // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
    // equality test.
    val converted: Seq[Row] = answer.map(prepareRow)
    if (!isSorted) converted.sortBy(_.toString()) else converted
  }

  // We need to call prepareRow recursively to handle schemas with struct types.
  def prepareRow(row: Row): Row = {
    Row.fromSeq(row.toSeq.map {
      case null => null
      case d: java.math.BigDecimal => BigDecimal(d)
      // Convert array to Seq for easy equality check.
      case b: Array[_] => b.toSeq
      case r: Row => prepareRow(r)
      case o => o
    })
  }

  def sameRows(
                expectedAnswer: Seq[Row],
                sparkAnswer: Seq[Row],
                isSorted: Boolean = false): Option[String] = {
    if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) {
      val errorMessage =
        s"""
           |== Results ==
           |${sideBySide(
          s"== Correct Answer - ${expectedAnswer.size} ==" +:
            prepareAnswer(expectedAnswer, isSorted).map(_.toString()),
          s"== Spark Answer - ${sparkAnswer.size} ==" +:
            prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")}
        """.stripMargin
      return Some(errorMessage)
    }
    None
  }

  def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = {
    val maxLeftSize = left.map(_.size).max
    val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("")
    val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("")

    leftPadded.zip(rightPadded).map {
      case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.size) + 3)) + r
    }
  }
} 
Example 183
Source File: TablePartitionColIntegrationTestSuite.scala    From spark-netezza   with Apache License 2.0 5 votes vote down vote up
package com.ibm.spark.netezza.integration

import org.apache.spark.sql.{DataFrame, Row}
import org.netezza.error.NzSQLException


class TablePartitionColIntegrationTestSuite extends IntegrationSuiteBase with QueryTest {
  val tabName = "staff"
  val expected = Seq(
    Row(1, "John Doe"),
    Row(2, "Jeff Smith"),
    Row(3, "Kathy Saunders"),
    Row(4, null))

  val expectedFiltered = Seq(Row(1, "John Doe"), Row(2, "Jeff Smith"))


  override def beforeAll(): Unit = {
    super.beforeAll()
    try {executeJdbcStmt(s"drop table $tabName")} catch { case e: NzSQLException => }
    executeJdbcStmt(s"create table $tabName(id int , name varchar(20))")
    executeJdbcStmt(s"insert into $tabName values(1 , 'John Doe')")
    executeJdbcStmt(s"insert into $tabName values(2 , 'Jeff Smith')")
    executeJdbcStmt(s"insert into $tabName values(3 , 'Kathy Saunders')")
    executeJdbcStmt(s"insert into $tabName values(4 , null)")
  }

  override def afterAll(): Unit = {
    try {
      executeJdbcStmt(s"DROP TABLE $tabName")
    } finally {
      super.afterAll()
    }
  }

  private def defaultOpts() = {
    Map("url" -> testURL,
      "user" -> user,
      "password" -> password,
      "numPartitions" -> Integer.toString(1))
  }


  test("Test table read with column partitions") {
    val opts = defaultOpts +
      ("dbtable" -> s"$tabName") +
      ("partitioncol" -> "ID") +
      ("numPartitions" -> Integer.toString(4)) +
      ("lowerbound" -> "1") +
      ("upperbound" -> "100")

    val testDf = sqlContext.read.format("com.ibm.spark.netezza").options(opts).load()
    verifyAnswer(testDf, expected)
    verifyAnswer(testDf.filter("ID < 3"), expectedFiltered)
  }

  test("Test table read specifying lower or upper boundary") {
    var opts = defaultOpts +
      ("dbtable" -> s"$tabName") +
      ("partitioncol" -> "ID") +
      ("numPartitions" -> Integer.toString(4))

    val testOpts = Seq(opts , opts + ("lowerbound" -> "1"), opts + ("upperbound" -> "10"))
    for (opts <- testOpts) {
      val testDf = sqlContext.read.format("com.ibm.spark.netezza").options(opts).load()
      verifyAnswer(testDf, expected)
      verifyAnswer(testDf.filter("ID < 3"), expectedFiltered)
    }
  }

  test("Test table read with single partition") {
    val opts = defaultOpts +
      ("dbtable" -> s"$tabName") +
      ("partitioncol" -> "ID") +
      ("numPartitions" -> Integer.toString(1))

    val testDf = sqlContext.read.format("com.ibm.spark.netezza").options(opts).load()
    verifyAnswer(testDf, expected)
    verifyAnswer(testDf.filter("ID < 3"), expectedFiltered)
  }

  test("Test table with number of partitions set to zero.") {
    val opts = defaultOpts +
      ("dbtable" -> s"$tabName") +
      ("partitioncol" -> "ID") +
      ("numPartitions" -> Integer.toString(0))

    val testDf = sqlContext.read.format("com.ibm.spark.netezza").options(opts).load()
    verifyAnswer(testDf, expected)
  }
} 
Example 184
Source File: IntegrationSuiteBase.scala    From spark-netezza   with Apache License 2.0 5 votes vote down vote up
package com.ibm.spark.netezza.integration

import java.sql.Connection

import com.ibm.spark.netezza.NetezzaJdbcUtils
import com.typesafe.config.ConfigFactory
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, DataFrame, SQLContext}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory

trait IntegrationSuiteBase extends FunSuite with BeforeAndAfterAll with QueryTest{
  private val log = LoggerFactory.getLogger(getClass)

  protected var sc: SparkContext = _
  protected var sqlContext: SQLContext = _
  protected var conn: Connection = _
  protected val prop = new java.util.Properties

  // Configurable vals
  protected var configFile = "application"
  protected var testURL: String = _
  protected var testTable: String = _
  protected var user: String = _
  protected var password: String = _
  protected var numPartitions: Int = _
  protected var sampleDbmaxNumTables: Int = _

  override def beforeAll(): Unit = {
    super.beforeAll()

    sc = new SparkContext("local[*]", "IntegrationTest", new SparkConf())
    sqlContext = new SQLContext(sc)

    val conf = ConfigFactory.load(configFile)
    testURL = conf.getString("test.integration.dbURL")
    testTable = conf.getString("test.integration.table")
    user = conf.getString("test.integration.user")
    password = conf.getString("test.integration.password")
    numPartitions = conf.getInt("test.integration.partition.number")
    sampleDbmaxNumTables = conf.getInt("test.integration.max.numtables")
    prop.setProperty("user", user)
    prop.setProperty("password", password)
    log.info("Attempting to get connection from" + testURL)
    conn = NetezzaJdbcUtils.getConnector(testURL, prop)()
    log.info("got connection.")
  }

  override def afterAll(): Unit = {
    try {
      sc.stop()
    }
    finally {
      conn.close()
      super.afterAll()
    }
  }

  
  def withTable(tableNames: String*)(f: => Unit): Unit = {
    try f finally {
      tableNames.foreach { name =>
        executeJdbcStmt(s"DROP TABLE $name")
      }
    }
  }
} 
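A minimal sketch tying the helpers together (hypothetical suite name; it needs a live Netezza instance with the staff table created in the previous example): checkAnswer returns None when the collected result matches the expectation and Some(diagnostics) otherwise, so a test can assert directly on the Option.

package com.ibm.spark.netezza.integration

import org.apache.spark.sql.Row

class StaffQuerySketch extends IntegrationSuiteBase {
  test("staff rows below id 3 match the expected answer") {
    val opts = Map(
      "url" -> testURL, "user" -> user, "password" -> password,
      "dbtable" -> "staff", "numPartitions" -> "1")
    val df = sqlContext.read.format("com.ibm.spark.netezza").options(opts).load()
    val maybeError = checkAnswer(df.filter("ID < 3"), Seq(Row(1, "John Doe"), Row(2, "Jeff Smith")))
    assert(maybeError.isEmpty, maybeError.getOrElse(""))
  }
}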
Example 185
Source File: RangerShowTablesCommand.scala    From spark-ranger   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.ranger.authorization.spark.authorizer.{RangerSparkAuthorizer, SparkPrivilegeObject, SparkPrivilegeObjectType}
import org.apache.spark.sql.execution.command.{RunnableCommand, ShowTablesCommand}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute

case class RangerShowTablesCommand(child: ShowTablesCommand) extends RunnableCommand {

  override val output: Seq[Attribute] = child.output
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val rows = child.run(sparkSession)
    rows.filter(r => RangerSparkAuthorizer.isAllowed(toSparkPrivilegeObject(r)))
  }

  private def toSparkPrivilegeObject(row: Row): SparkPrivilegeObject = {
    val database = row.getString(0)
    val table = row.getString(1)
    new SparkPrivilegeObject(SparkPrivilegeObjectType.TABLE_OR_VIEW, database, table)
  }
} 
Example 186
Source File: RangerShowDatabasesCommand.scala    From spark-ranger   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.ranger.authorization.spark.authorizer.{RangerSparkAuthorizer, SparkPrivilegeObject, SparkPrivilegeObjectType}
import org.apache.spark.sql.execution.command.{RunnableCommand, ShowDatabasesCommand}
import org.apache.spark.sql.{Row, SparkSession}

case class RangerShowDatabasesCommand(child: ShowDatabasesCommand) extends RunnableCommand {
  override val output = child.output

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val rows = child.run(sparkSession)
    rows.filter(r => RangerSparkAuthorizer.isAllowed(toSparkPrivilegeObject(r)))
  }

  private def toSparkPrivilegeObject(row: Row): SparkPrivilegeObject = {
    val database = row.getString(0)
    new SparkPrivilegeObject(SparkPrivilegeObjectType.DATABASE, database, database)
  }


} 
Example 187
Source File: functions.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{array, col, explode, udf}
import org.apache.spark.sql.types.DataType

import scala.reflect.runtime.universe._

object functions {

  implicit class FilterAnnotations(dataset: DataFrame) {
    def filterByAnnotationsCol(column: String, function: Seq[Annotation] => Boolean): DataFrame = {
      val meta = dataset.schema(column).metadata
      val func = udf {
        annotatorProperties: Seq[Row] =>
          function(annotatorProperties.map(Annotation(_)))
      }
      dataset.filter(func(col(column)).as(column, meta))
    }
  }

  def mapAnnotations[T](function: Seq[Annotation] => T, outputType: DataType): UserDefinedFunction = udf ( {
    annotatorProperties: Seq[Row] =>
      function(annotatorProperties.map(Annotation(_)))
  }, outputType)

  def mapAnnotationsStrict(function: Seq[Annotation] => Seq[Annotation]): UserDefinedFunction = udf {
    annotatorProperties: Seq[Row] =>
      function(annotatorProperties.map(Annotation(_)))
  }

  implicit class MapAnnotations(dataset: DataFrame) {
    def mapAnnotationsCol[T: TypeTag](column: String, outputCol: String, function: Seq[Annotation] => T): DataFrame = {
      val meta = dataset.schema(column).metadata
      val func = udf {
        annotatorProperties: Seq[Row] =>
          function(annotatorProperties.map(Annotation(_)))
      }
      dataset.withColumn(outputCol, func(col(column)).as(outputCol, meta))
    }
  }

  implicit class EachAnnotations(dataset: DataFrame) {

    import dataset.sparkSession.implicits._

    def eachAnnotationsCol[T: TypeTag](column: String, function: Seq[Annotation] => Unit): Unit = {
      dataset.select(column).as[Array[Annotation]].foreach(function(_))
    }
  }

  implicit class ExplodeAnnotations(dataset: DataFrame) {
    def explodeAnnotationsCol[T: TypeTag](column: String, outputCol: String): DataFrame = {
      val meta = dataset.schema(column).metadata
      dataset.
        withColumn(outputCol, explode(col(column))).
        withColumn(outputCol, array(col(outputCol)).as(outputCol, meta))
    }
  }

} 
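A minimal usage sketch (the DataFrame and column names are illustrative) showing the implicit extensions above, brought into scope with a single import of com.johnsnowlabs.nlp.functions._:

import com.johnsnowlabs.nlp.functions._
import org.apache.spark.sql.DataFrame

object AnnotationFunctionsSketch {
  // `annotated` is assumed to already carry spark-nlp annotation columns
  // named "sentence", "token" and "chunk"
  def demo(annotated: DataFrame): DataFrame = {
    // keep only rows whose sentences are all longer than ten characters
    val filtered = annotated.filterByAnnotationsCol("sentence", anns => anns.forall(_.result.length > 10))
    // project the raw token strings into a new "token_text" column
    val withTokenText = filtered.mapAnnotationsCol[Seq[String]]("token", "token_text", anns => anns.map(_.result))
    // one output row per chunk annotation, preserving the column metadata
    withTokenText.explodeAnnotationsCol("chunk", "chunk")
  }
}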
Example 188
Source File: PubTator.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.training

import com.johnsnowlabs.nlp.annotator.{PerceptronModel, SentenceDetector, Tokenizer}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, DocumentAssembler, Finisher}
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}


object PubTator {

  def readDataset(spark: SparkSession, path: String): DataFrame = {
    val pubtator = spark.sparkContext.textFile(path)
    val titles = pubtator.filter(x => x.contains("|a|") | x.contains("|t|"))
    val titlesText = titles.map(x => x.split("\\|")).groupBy(_.head)
      .map(x => (x._1.toInt, x._2.foldLeft(Seq[String]())((a, b) => a ++ Seq(b.last)))).map(x => (x._1, x._2.mkString(" ")))
    val df = spark.createDataFrame(titlesText).toDF("doc_id", "text")
    val docAsm = new DocumentAssembler().setInputCol("text").setOutputCol("document")
    val setDet = new SentenceDetector().setInputCols("document").setOutputCol("sentence")
    val tknz = new Tokenizer().setInputCols("sentence").setOutputCol("token")
    val pl = new Pipeline().setStages(Array(docAsm, setDet, tknz))
    val nlpDf = pl.fit(df).transform(df)
    val annotations = pubtator.filter(x => !x.contains("|a|") & !x.contains("|t|") & x.nonEmpty)
    val splitAnnotations = annotations.map(_.split("\\t")).map(x => (x(0), x(1).toInt, x(2).toInt - 1, x(3), x(4), x(5)))
    val docAnnotations = splitAnnotations.groupBy(_._1).map(x => (x._1, x._2))
      .map(x =>
        (x._1.toInt,
          x._2.zipWithIndex.map(a => (new Annotation(AnnotatorType.CHUNK, a._1._2, a._1._3, a._1._4, Map("entity" -> a._1._5, "chunk" -> a._2.toString), Array[Float]()))).toList
        )
      )
    val chunkMeta = new MetadataBuilder().putString("annotatorType", AnnotatorType.CHUNK).build()
    val annDf = spark.createDataFrame(docAnnotations).toDF("doc_id", "chunk")
      .withColumn("chunk", col("chunk").as("chunk", chunkMeta))
    val alignedDf = nlpDf.join(annDf, Seq("doc_id")).selectExpr("doc_id", "sentence", "token", "chunk")
    val iobTagging = udf((tokens: Seq[Row], chunkLabels: Seq[Row]) => {
      val tokenAnnotations = tokens.map(Annotation(_))
      val labelAnnotations = chunkLabels.map(Annotation(_))
      tokenAnnotations.map(ta => {
        val tokenLabel = labelAnnotations.filter(la => la.begin <= ta.begin && la.end >= ta.end).headOption
        val tokenTag = {
          if (tokenLabel.isEmpty) "O"
          else {
            val tokenCSV = tokenLabel.get.metadata.get("entity").get
            if (tokenCSV == "UnknownType") "O"
            else {
              val tokenPrefix = if (ta.begin == tokenLabel.get.begin) "B-" else "I-"
              val paddedTokenTag = "T" + "%03d".format(tokenCSV.split(",")(0).slice(1, 4).toInt)
              tokenPrefix + paddedTokenTag
            }
          }
        }

        Annotation(AnnotatorType.NAMED_ENTITY,
          ta.begin, ta.end,
          tokenTag,
          Map("word" -> ta.result)
        )
      }
      )
    })
    val labelMeta = new MetadataBuilder().putString("annotatorType", AnnotatorType.NAMED_ENTITY).build()
    val taggedDf = alignedDf.withColumn("label", iobTagging(col("token"), col("chunk")).as("label", labelMeta))

    val pos = PerceptronModel.pretrained().setInputCols(Array("sentence", "token")).setOutputCol("pos")
    val finisher = new Finisher().setInputCols("token", "pos", "label").setIncludeMetadata(true)
    val finishingPipeline = new Pipeline().setStages(Array(pos, finisher))
    finishingPipeline.fit(taggedDf).transform(taggedDf)
      .withColumnRenamed("finished_label", "finished_ner") //CoNLL generator expects finished_ner
  }
} 
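A minimal usage sketch (the input path is hypothetical; PerceptronModel.pretrained() above downloads a model, so network access is required). After the rename in readDataset, the returned DataFrame carries a finished_ner column alongside doc_id:

import com.johnsnowlabs.nlp.training.PubTator
import org.apache.spark.sql.SparkSession

object PubTatorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("PubTatorSketch")
      .master("local[*]")
      .getOrCreate()

    // any PubTator-formatted corpus file works here
    val training = PubTator.readDataset(spark, "src/test/resources/corpus_pubtator_sample.txt")
    training.select("doc_id", "finished_ner").show(5, truncate = false)

    spark.stop()
  }
}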
Example 189
Source File: DataBuilder.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.training.CoNLL
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest._


object DataBuilder extends FlatSpec with BeforeAndAfterAll { this: Suite =>

  import SparkAccessor.spark.implicits._

  def basicDataBuild(content: String*)(implicit cleanupMode: String = "disabled"): Dataset[Row] = {
    val data = SparkAccessor.spark.sparkContext.parallelize(content).toDS().toDF("text")
    AnnotatorBuilder.withDocumentAssembler(data, cleanupMode)
  }

  def multipleDataBuild(content: Seq[String]): Dataset[Row] = {
    val data = SparkAccessor.spark.sparkContext.parallelize(content).toDS().toDF("text")
    AnnotatorBuilder.withDocumentAssembler(data)
  }

  def buildNerDataset(datasetContent: String): Dataset[Row] = {
    val lines = datasetContent.split("\n")
    val data = CoNLL(conllLabelIndex = 1)
      .readDatasetFromLines(lines, SparkAccessor.spark).toDF
    AnnotatorBuilder.withDocumentAssembler(data)
  }

  def loadParquetDataset(path: String) =
    SparkAccessor.spark.read.parquet(path)
} 
Example 190
Source File: DocumentAssemblerTestSpec.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp

import org.scalatest._
import org.apache.spark.sql.Row
import scala.language.reflectiveCalls
import Matchers._

class DocumentAssemblerTestSpec extends FlatSpec {
  def fixture = new {
    val text = ContentProvider.englishPhrase
    val df = AnnotatorBuilder.withDocumentAssembler(DataBuilder.basicDataBuild(text))
    val assembledDoc = df
      .select("document")
      .collect
      .flatMap { _.getSeq[Row](0) }
      .map { Annotation(_) }
  }

  "A DocumentAssembler" should "annotate with the correct indexes" in {
    val f = fixture
    f.text.head should equal (f.text(f.assembledDoc.head.begin))
    f.text.last should equal (f.text(f.assembledDoc.head.end))
  }
} 
Example 191
Source File: BigTextMatcherBehaviors.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.annotators.btm

import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest._

trait BigTextMatcherBehaviors { this: FlatSpec =>

  def fullBigTextMatcher(dataset: => Dataset[Row]) {
    "An BigTextMatcher Annotator" should "successfully transform data" in {
      AnnotatorBuilder.withFullBigTextMatcher(dataset)
        .collect().foreach {
        row =>
          row.getSeq[Row](3)
            .map(Annotation(_))
            .foreach {
              case entity: Annotation if entity.annotatorType == "entity" =>
                println(entity, entity.end)
              case _ => ()
            }
      }
    }
  }
} 
Example 192
Source File: DependencyParserBehaviors.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.annotators.parser.dep

import com.johnsnowlabs.nlp._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.scalatest.FlatSpec
import com.johnsnowlabs.util.PipelineModels
import org.apache.spark.ml.Pipeline

trait DependencyParserBehaviors { this: FlatSpec =>


  def initialAnnotations(testDataSet: Dataset[Row]): Unit = {
    val fixture = createFixture(testDataSet)
    it should "add annotations" in {
      assert(fixture.dependencies.count > 0, "Annotations count should be greater than 0")
    }

    it should "add annotations with the correct annotationType" in {
      fixture.depAnnotations.foreach { a =>
        assert(a.annotatorType == AnnotatorType.DEPENDENCY, s"Annotation type should be ${AnnotatorType.DEPENDENCY}")
      }
    }

    it should "annotate each token" in {
      assert(fixture.tokenAnnotations.size == fixture.depAnnotations.size, s"Every token should be annotated")
    }

    it should "annotate each word with a head" in {
      fixture.depAnnotations.foreach { a =>
        assert(a.result.nonEmpty, s"Result should have a head")
      }
    }

    it should "annotate each word with the correct indexes" in {
      fixture.depAnnotations
        .zip(fixture.tokenAnnotations)
        .foreach { case (dep, token) => assert(dep.begin == token.begin && dep.end == token.end, s"Token and word should have equal indexes") }
    }
  }

  private def createFixture(testDataSet: Dataset[Row]) = new {
    val dependencies: DataFrame = testDataSet.select("dependency")
    val depAnnotations: Seq[Annotation] = dependencies
      .collect
      .flatMap { r => r.getSeq[Row](0) }
      .map { r =>
        Annotation(r.getString(0), r.getInt(1), r.getInt(2), r.getString(3), r.getMap[String, String](4))
      }
    val tokens: DataFrame = testDataSet.select("token")
    val tokenAnnotations: Seq[Annotation] = tokens
      .collect
      .flatMap { r => r.getSeq[Row](0) }
      .map { r =>
        Annotation(r.getString(0), r.getInt(1), r.getInt(2), r.getString(3), r.getMap[String, String](4))
      }
  }

  def relationshipsBetweenWordsPredictor(testDataSet: Dataset[Row], pipeline: Pipeline): Unit = {

    val emptyDataSet = PipelineModels.dummyDataset

    val dependencyParserModel = pipeline.fit(emptyDataSet)

    it should "train a model" in {
      val model = dependencyParserModel.stages.last.asInstanceOf[DependencyParserModel]
      assert(model.isInstanceOf[DependencyParserModel])
    }

    val dependencyParserDataFrame = dependencyParserModel.transform(testDataSet)
    //dependencyParserDataFrame.collect()
    dependencyParserDataFrame.select("dependency").show(false)

    it should "predict relationships between words" in {
      assert(dependencyParserDataFrame.isInstanceOf[DataFrame])
    }

  }

} 
Example 193
Source File: CombinedTestSpec.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp._
import org.apache.spark.sql.Row
import org.scalatest._
import com.johnsnowlabs.nlp.AnnotatorType._

class CombinedTestSpec extends FlatSpec {

  "Simple combined annotators" should "successfully go through all transformations" in {
    val data = DataBuilder.basicDataBuild("This is my first sentence. This is your second list of words")
    val transformation = AnnotatorBuilder.withLemmaTaggedSentences(data)
    transformation
      .collect().foreach {
      row =>
        row.getSeq[Row](1).map(Annotation(_)).foreach { token =>
          // Document annotation
          assert(token.annotatorType == DOCUMENT)
        }
        row.getSeq[Row](2).map(Annotation(_)).foreach { token =>
          // SBD annotation
          assert(token.annotatorType == DOCUMENT)
        }
        row.getSeq[Row](4).map(Annotation(_)).foreach { token =>
          // POS annotation
          assert(token.annotatorType == POS)
        }
    }
  }
} 
Example 194
Source File: TokenizerBehaviors.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder, AnnotatorType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest._

import scala.language.reflectiveCalls

trait TokenizerBehaviors { this: FlatSpec =>

  def fixture(dataset: => Dataset[Row]) = new {
    val df = AnnotatorBuilder.withTokenizer(AnnotatorBuilder.withTokenizer(dataset))
    val documents = df.select("document")
    val sentences = df.select("sentence")
    val tokens = df.select("token")
    val sentencesAnnotations = sentences
      .collect
      .flatMap { r => r.getSeq[Row](0) }
      .map { a => Annotation(a.getString(0), a.getInt(1), a.getInt(2), a.getString(3), a.getMap[String, String](4)) }
    val tokensAnnotations = tokens
      .collect
      .flatMap { r => r.getSeq[Row](0)}
      .map { a => Annotation(a.getString(0), a.getInt(1), a.getInt(2), a.getString(3), a.getMap[String, String](4)) }

    val docAnnotations = documents
      .collect
      .flatMap { r => r.getSeq[Row](0)}
      .map { a => Annotation(a.getString(0), a.getInt(1), a.getInt(2), a.getString(3), a.getMap[String, String](4)) }

    val corpus = docAnnotations
      .map(d => d.result)
      .mkString("")
  }

  def fullTokenizerPipeline(dataset: => Dataset[Row]) {
    "A Tokenizer Annotator" should "successfully transform data" in {
      val f = fixture(dataset)
      assert(f.tokensAnnotations.nonEmpty, "Tokenizer should add annotations")
    }

    it should "annotate using the annotatorType of token" in {
      val f = fixture(dataset)
      assert(f.tokensAnnotations.nonEmpty, "Tokenizer should add annotations")
      f.tokensAnnotations.foreach { a =>
        assert(a.annotatorType == AnnotatorType.TOKEN, "Tokenizer annotation type should be equal to 'token'")
      }
    }

    it should "annotate with the correct word indexes" in {
      val f = fixture(dataset)
      f.tokensAnnotations.foreach { a =>
        val token = a.result
        val sentenceToken = f.corpus.slice(a.begin, a.end + 1)
        assert(sentenceToken == token, s"Word ($sentenceToken) from sentence at (${a.begin},${a.end}) should be equal to token ($token) inside the corpus ${f.corpus}")
      }
    }
  }
} 
Example 195
Source File: TextMatcherTestSpec.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.AnnotatorType._
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
import com.johnsnowlabs.nlp.util.io.ReadAs
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest._


class TextMatcherTestSpec extends FlatSpec with TextMatcherBehaviors {

  "An TextMatcher" should s"be of type $CHUNK" in {
    val entityExtractor = new TextMatcherModel
    assert(entityExtractor.outputAnnotatorType == CHUNK)
  }

  "A TextMatcher" should "extract entities with and without sentences" in {
    val dataset = DataBuilder.basicDataBuild("Hello dolore magna aliqua. Lorem ipsum dolor. sit in laborum")
    val result = AnnotatorBuilder.withFullTextMatcher(dataset)
    val resultNoSentence = AnnotatorBuilder.withFullTextMatcher(dataset, sbd = false)
    val resultNoSentenceNoCase = AnnotatorBuilder.withFullTextMatcher(dataset, sbd = false, caseSensitive = false)
    val extractedSentenced = Annotation.collect(result, "entity").flatten.toSeq
    val extractedNoSentence = Annotation.collect(resultNoSentence, "entity").flatten.toSeq
    val extractedNoSentenceNoCase = Annotation.collect(resultNoSentenceNoCase, "entity").flatten.toSeq

    val expectedSentenced = Seq(
      Annotation(CHUNK, 6, 24, "dolore magna aliqua", Map("entity"->"entity", "sentence" -> "0", "chunk" -> "0")),
      Annotation(CHUNK, 53, 59, "laborum", Map("entity"->"entity", "sentence" -> "2", "chunk" -> "1"))
    )

    val expectedNoSentence = Seq(
      Annotation(CHUNK, 6, 24, "dolore magna aliqua", Map("entity"->"entity", "sentence" -> "0", "chunk" -> "0")),
      Annotation(CHUNK, 53, 59, "laborum", Map("entity"->"entity", "sentence" -> "0", "chunk" -> "1"))
    )

    val expectedNoSentenceNoCase = Seq(
      Annotation(CHUNK, 6, 24, "dolore magna aliqua", Map("entity"->"entity", "sentence" -> "0", "chunk" -> "0")),
      Annotation(CHUNK, 27, 48, "Lorem ipsum dolor. sit", Map("entity"->"entity", "sentence" -> "0", "chunk" -> "1")),
      Annotation(CHUNK, 53, 59, "laborum", Map("entity"->"entity", "sentence" -> "0", "chunk" -> "2"))
    )

    assert(extractedSentenced == expectedSentenced)
    assert(extractedNoSentence == expectedNoSentence)
    assert(extractedNoSentenceNoCase == expectedNoSentenceNoCase)
  }

  "An Entity Extractor" should "search inside sentences" in {
    val dataset = DataBuilder.basicDataBuild("Hello dolore magna. Aliqua")
    val result = AnnotatorBuilder.withFullTextMatcher(dataset, caseSensitive = false)
    val extracted = Annotation.collect(result, "entity").flatten.toSeq

    assert(extracted == Seq.empty[Annotation])
  }

  "A Recursive Pipeline TextMatcher" should "extract entities from dataset" in {
    val data = ContentProvider.parquetData.limit(1000)

    val documentAssembler = new DocumentAssembler()
      .setInputCol("text")
      .setOutputCol("document")

    val sentenceDetector = new SentenceDetector()
      .setInputCols(Array("document"))
      .setOutputCol("sentence")

    val tokenizer = new Tokenizer()
      .setInputCols(Array("sentence"))
      .setOutputCol("token")

    val entityExtractor = new TextMatcher()
      .setInputCols("sentence", "token")
      .setEntities("src/test/resources/entity-extractor/test-phrases.txt", ReadAs.TEXT)
      .setOutputCol("entity")

    val finisher = new Finisher()
      .setInputCols("entity")
      .setOutputAsArray(false)
      .setAnnotationSplitSymbol("@")
      .setValueSplitSymbol("#")

    val recursivePipeline = new RecursivePipeline()
      .setStages(Array(
        documentAssembler,
        sentenceDetector,
        tokenizer,
        entityExtractor,
        finisher
      ))

    recursivePipeline.fit(data).transform(data).show(false)
    assert(recursivePipeline.fit(data).transform(data).filter("finished_entity == ''").count > 0)
  }

  val latinBodyData: Dataset[Row] = DataBuilder.basicDataBuild(ContentProvider.latinBody)

  "A full Normalizer pipeline with latin content" should behave like fullTextMatcher(latinBodyData)

} 
Example 196
Source File: LemmatizerTestSpec.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
import com.johnsnowlabs.nlp._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest._


class LemmatizerTestSpec extends FlatSpec with LemmatizerBehaviors {

  require(Some(SparkAccessor).isDefined)

  val lemmatizer = new Lemmatizer
  "a lemmatizer" should s"be of type ${AnnotatorType.TOKEN}" in {
    assert(lemmatizer.outputAnnotatorType == AnnotatorType.TOKEN)
  }

  val latinBodyData: Dataset[Row] = DataBuilder.basicDataBuild(ContentProvider.latinBody)

  "A full Normalizer pipeline with latin content" should behave like fullLemmatizerPipeline(latinBodyData)

  "A lemmatizer" should "be readable and writable" taggedAs Tag("LinuxOnly") in {
    val lemmatizer = new Lemmatizer().setDictionary("src/test/resources/lemma-corpus-small/lemmas_small.txt", "->", "\t")
    val path = "./test-output-tmp/lemmatizer"
    try {
      lemmatizer.write.overwrite.save(path)
      val lemmatizerRead = Lemmatizer.read.load(path)
      assert(lemmatizer.getDictionary.path == lemmatizerRead.getDictionary.path)
    } catch {
      case _: java.io.IOException => succeed
    }
  }

  "A lemmatizer" should "work under a pipeline framework" in {

    val data = ContentProvider.parquetData.limit(1000)

    val documentAssembler = new DocumentAssembler()
      .setInputCol("text")
      .setOutputCol("document")

    val sentenceDetector = new SentenceDetector()
      .setInputCols(Array("document"))
      .setOutputCol("sentence")

    val tokenizer = new Tokenizer()
      .setInputCols(Array("sentence"))
      .setOutputCol("token")

    val lemmatizer = new Lemmatizer()
      .setInputCols(Array("token"))
      .setOutputCol("lemma")
      .setDictionary("src/test/resources/lemma-corpus-small/lemmas_small.txt", "->", "\t")

    val finisher = new Finisher()
      .setInputCols("lemma")

    val pipeline = new Pipeline()
      .setStages(Array(
        documentAssembler,
        sentenceDetector,
        tokenizer,
        lemmatizer,
        finisher
      ))

    val recursivePipeline = new RecursivePipeline()
      .setStages(Array(
        documentAssembler,
        sentenceDetector,
        tokenizer,
        lemmatizer,
        finisher
      ))

    val model = pipeline.fit(data)
    model.transform(data).show()

    val PIPE_PATH = "./tmp_pipeline"

    model.write.overwrite().save(PIPE_PATH)
    val loadedPipeline = PipelineModel.read.load(PIPE_PATH)
    loadedPipeline.transform(data).show

    val recursiveModel = recursivePipeline.fit(data)
    recursiveModel.transform(data).show()

    recursiveModel.write.overwrite().save(PIPE_PATH)
    val loadedRecPipeline = PipelineModel.read.load(PIPE_PATH)
    loadedRecPipeline.transform(data).show

    succeed
  }

} 
Example 197
Source File: PragmaticSentimentBehaviors.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.annotators.sda.pragmatic

import com.johnsnowlabs.nlp.annotators.common.TokenizedSentence
import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest._
import com.johnsnowlabs.nlp.AnnotatorType.SENTIMENT
import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper}

import scala.language.reflectiveCalls

trait PragmaticSentimentBehaviors { this: FlatSpec =>

  def fixture(dataset: Dataset[Row]) = new {
    val df = AnnotatorBuilder.withPragmaticSentimentDetector(dataset)
    val sdAnnotations = Annotation.collect(df, "sentiment").flatten
  }

  def isolatedSentimentDetector(tokenizedSentences: Array[TokenizedSentence], expectedScore: Double): Unit = {
    s"tagged sentences" should s"have an expected score of $expectedScore" in {
      val pragmaticScorer = new PragmaticScorer(ResourceHelper.parseKeyValueText(ExternalResource("src/test/resources/sentiment-corpus/default-sentiment-dict.txt", ReadAs.TEXT, Map("delimiter" -> ","))))
      val result = pragmaticScorer.score(tokenizedSentences)
      assert(result == expectedScore, s"because result: $result did not match expected: $expectedScore")
    }
  }

  def sparkBasedSentimentDetector(dataset: => Dataset[Row]): Unit = {

    "A Pragmatic Sentiment Analysis Annotator" should s"create annotations" in {
      val f = fixture(dataset)
      assert(f.sdAnnotations.size > 0)
    }

    it should "create annotations with the correct type" in {
      val f = fixture(dataset)
      f.sdAnnotations.foreach { a =>
        assert(a.annotatorType == SENTIMENT)
      }
    }

    it should "successfully score sentences" in {
      val f = fixture(dataset)
      f.sdAnnotations.foreach { a =>
        assert(List("positive", "negative").contains(a.result))
      }
    }
  }
} 
Example 198
Source File: TextMatcherBehaviors.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest._

trait TextMatcherBehaviors { this: FlatSpec =>

  def fullTextMatcher(dataset: => Dataset[Row]) {
    "An TextMatcher Annotator" should "successfully transform data" in {
      AnnotatorBuilder.withFullTextMatcher(dataset)
        .collect().foreach {
        row =>
          row.getSeq[Row](3)
            .map(Annotation(_))
            .foreach {
              case entity: Annotation if entity.annotatorType == "entity" =>
                println(entity, entity.end)
              case _ => ()
            }
      }
    }
  }
} 
Example 199
Source File: MultiDateMatcherBehaviors.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.AnnotatorType.DATE
import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.Matchers._
import org.scalatest._

import scala.language.reflectiveCalls

trait MultiDateMatcherBehaviors extends FlatSpec {
  def fixture(dataset: Dataset[Row]) = new {
    val df = AnnotatorBuilder.withMultiDateMatcher(dataset)
    val dateAnnotations = df.select("date")
      .collect
      .flatMap { _.getSeq[Row](0) }
      .map { Annotation(_) }
  }

  def sparkBasedDateMatcher(dataset: => Dataset[Row]): Unit = {
    "A MultiDateMatcher Annotator" should s"successfuly parse dates" in {
      val f = fixture(dataset)
      f.dateAnnotations.foreach { a =>
        val d: String = a.result
        d should fullyMatch regex """\d+/\d+/\d+"""
      }
    }

    it should "create annotations" in {
      val f = fixture(dataset)
      assert(f.dateAnnotations.size > 0)
    }

    it should "create annotations with the correct type" in {
      val f = fixture(dataset)
      f.dateAnnotations.foreach { a =>
        assert(a.annotatorType == DATE)
      }
    }
  }
} 
Example 200
Source File: PragmaticDetectionBehaviors.scala    From spark-nlp   with Apache License 2.0 5 votes vote down vote up
package com.johnsnowlabs.nlp.annotators.sbd.pragmatic

import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder, AnnotatorType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest._

import scala.language.reflectiveCalls

trait PragmaticDetectionBehaviors { this: FlatSpec =>

  def fixture(dataset: => Dataset[Row]) = new {
    val df = AnnotatorBuilder.withFullPragmaticSentenceDetector(dataset)
    val documents = df.select("document")
    val sentences = df.select("sentence")
    val sentencesAnnotations = sentences
      .collect
      .flatMap { r => r.getSeq[Row](0) }
      .map { a => Annotation(a.getString(0), a.getInt(1), a.getInt(2), a.getString(3), a.getMap[String, String](4)) }
    val corpus = sentencesAnnotations
      .flatMap { a => a.result }
      .mkString("")
  }

  private def f1Score(result: Array[String], expected: Array[String]): Double = {
    val nMatches = result.count(expected.contains(_))
    val nOutput = result.length
    val nExpected = expected.length
    val precision = nMatches / nOutput.toDouble
    val recall = nMatches / nExpected.toDouble
    (2 * precision * recall) / (precision + recall)
  }

  def isolatedPDReadAndMatchResult(input: String, correctAnswer: Array[String], customBounds: Array[String] = Array.empty[String]): Unit = {
    s"pragmatic boundaries detector with ${input.take(10)}...:" should
      s"successfully identify sentences as ${correctAnswer.take(1).take(10).mkString}..." in {
      val pragmaticApproach = new MixedPragmaticMethod(true, customBounds)
      val result = pragmaticApproach.extractBounds(input)
      val diffInResult = result.map(_.content).diff(correctAnswer)
      val diffInCorrect = correctAnswer.diff(result.map(_.content))
      assert(
        result.map(_.content).sameElements(correctAnswer),
        s"\n--------------\nSENTENCE IS WRONG:\n--------------\n$input" +
        s"\n--------------\nBECAUSE RESULT:\n--------------\n@@${diffInResult.mkString("\n@@")}" +
          s"\n--------------\nIS NOT EXPECTED:\n--------------\n@@${diffInCorrect.mkString("\n@@")}")
      assert(result.forall(sentence => {
        sentence.end == sentence.start + sentence.content.length - 1
      }), "because length mismatch")
    }
  }

  def isolatedPDReadAndMatchResultTag(input: String, correctAnswer: Array[String], customBounds: Array[String] = Array.empty[String], splitLength: Option[Int] = None): Unit = {
    s"pragmatic boundaries detector with ${input.take(10)}...:" should
      s"successfully identify sentences as ${correctAnswer.take(1).take(10).mkString}..." in {
      val sentenceDetector = new SentenceDetector()
      if (splitLength.isDefined)
        sentenceDetector.setSplitLength(splitLength.get)
      val result = sentenceDetector.tag(input).map(_.content)
      val diffInResult = result.diff(correctAnswer)
      val diffInCorrect = correctAnswer.diff(result)
      assert(
        result.sameElements(correctAnswer),
        s"\n--------------\nSENTENCE IS WRONG:\n--------------\n$input" +
          s"\n--------------\nBECAUSE RESULT:\n--------------\n@@${diffInResult.mkString("\n@@")}" +
          s"\n--------------\nIS NOT EXPECTED:\n--------------\n@@${diffInCorrect.mkString("\n@@")}")
    }
  }

  def isolatedPDReadScore(input: String, correctAnswer: Array[String], customBounds: Array[String] = Array.empty[String]): Unit = {
    s"boundaries prediction" should s"have an F1 score higher than 95%" in {
      val pragmaticApproach = new MixedPragmaticMethod(true, customBounds)
      val result = pragmaticApproach.extractBounds(input).map(_.content)
      val f1 = f1Score(result, correctAnswer)
      val unmatched = result.zip(correctAnswer).toMap.mapValues("\n"+_)
      info(s"F1 Score is: $f1")
      assert(f1 > 0.95, s"F1 Score is below 95%.\nResult vs expected sentences:\n${unmatched.mkString("\n")}")
    }
  }

  def sparkBasedSentenceDetector(dataset: => Dataset[Row]): Unit = {
    "a Pragmatic Sentence Detection Annotator" should s"successfully annotate documents" in {
      val f = fixture(dataset)
      assert(f.sentencesAnnotations.nonEmpty, "Annotations should exist")
    }

    it should "add annotators of type sbd" in {
      val f = fixture(dataset)
      f.sentencesAnnotations.foreach { a =>
        assert(a.annotatorType == AnnotatorType.DOCUMENT, s"annotatorType should be ${AnnotatorType.DOCUMENT}")
      }
    }
  }
}