org.apache.spark.sql.types.StringType Scala Examples

The following examples show how to use org.apache.spark.sql.types.StringType, drawn from several open-source Spark projects. Each example lists the source file, the project it comes from, and that project's license.
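Before the project examples, here is a minimal self-contained sketch of the basic usage: declaring a StringType column in a StructType schema and building a DataFrame against it. The object name StringTypeQuickStart and the sample data are illustrative only and do not come from any of the projects below.

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object StringTypeQuickStart {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("StringTypeQuickStart").getOrCreate()

    // A schema with one nullable string column.
    val schema = StructType(Seq(StructField("name", StringType, nullable = true)))

    val rows = spark.sparkContext.parallelize(Seq(Row("alice"), Row("bob")))
    val df = spark.createDataFrame(rows, schema)

    df.printSchema() // name: string (nullable = true)
    spark.stop()
  }
}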
Example 1
Source File: TextFileFormat.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.text

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.apache.hadoop.util.ReflectionUtils

import org.apache.spark.TaskContext
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration


  // Returns the file-name extension for the configured output compression codec,
  // or an empty string when compression is disabled.
  def getCompressionExtension(context: TaskAttemptContext): String = {
    // Set the compression extension, similar to code in TextOutputFormat.getDefaultWorkFile
    if (FileOutputFormat.getCompressOutput(context)) {
      val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec])
      ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension
    } else {
      ""
    }
  }
} 
Example 2
Source File: BooleanStatement.scala    From spark-snowflake   with Apache License 2.0
package net.snowflake.spark.snowflake.pushdowns.querygeneration

import net.snowflake.spark.snowflake.{ConstantString, SnowflakeSQLStatement}
import org.apache.spark.sql.catalyst.expressions.{
  Attribute,
  Contains,
  EndsWith,
  EqualTo,
  Expression,
  GreaterThan,
  GreaterThanOrEqual,
  In,
  IsNotNull,
  IsNull,
  LessThan,
  LessThanOrEqual,
  Literal,
  Not,
  StartsWith
}
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String


private[querygeneration] object BooleanStatement {
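  // Translates boolean-valued Catalyst expressions (IN with literal lists, IS [NOT] NULL,
  // negations, and Contains/StartsWith/EndsWith on string literals) into Snowflake SQL;
  // unapply returns None for expressions that cannot be pushed down.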
  def unapply(
    expAttr: (Expression, Seq[Attribute])
  ): Option[SnowflakeSQLStatement] = {
    val expr = expAttr._1
    val fields = expAttr._2

    Option(expr match {
      case In(child, list) if list.forall(_.isInstanceOf[Literal]) =>
        convertStatement(child, fields) + "IN" +
          blockStatement(convertStatements(fields, list: _*))
      case IsNull(child) =>
        blockStatement(convertStatement(child, fields) + "IS NULL")
      case IsNotNull(child) =>
        blockStatement(convertStatement(child, fields) + "IS NOT NULL")
      case Not(child) => {
        child match {
          case EqualTo(left, right) =>
            blockStatement(
              convertStatement(left, fields) + "!=" +
                convertStatement(right, fields)
            )
          case GreaterThanOrEqual(left, right) =>
            convertStatement(LessThan(left, right), fields)
          case LessThanOrEqual(left, right) =>
            convertStatement(GreaterThan(left, right), fields)
          case GreaterThan(left, right) =>
            convertStatement(LessThanOrEqual(left, right), fields)
          case LessThan(left, right) =>
            convertStatement(GreaterThanOrEqual(left, right), fields)
          case _ =>
            ConstantString("NOT") +
              blockStatement(convertStatement(child, fields))
        }
      }
      case Contains(child, Literal(pattern: UTF8String, StringType)) =>
        convertStatement(child, fields) + "LIKE" + s"'%${pattern.toString}%'"
      case EndsWith(child, Literal(pattern: UTF8String, StringType)) =>
        convertStatement(child, fields) + "LIKE" + s"'%${pattern.toString}'"
      case StartsWith(child, Literal(pattern: UTF8String, StringType)) =>
        convertStatement(child, fields) + "LIKE" + s"'${pattern.toString}%'"

      case _ => null
    })
  }
} 
Example 3
Source File: OnErrorSuite.scala    From spark-snowflake   with Apache License 2.0
package net.snowflake.spark.snowflake

import net.snowflake.client.jdbc.SnowflakeSQLException
import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_NAME
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class OnErrorSuite extends IntegrationSuiteBase {
  lazy val table = s"spark_test_table_$randomSuffix"

  lazy val schema = new StructType(
    Array(StructField("var", StringType, nullable = false))
  )

  lazy val df: DataFrame = sparkSession.createDataFrame(
    sc.parallelize(
      Seq(Row("{\"dsadas\nadsa\":12311}"), Row("{\"abc\":334}")) // invalid json key
    ),
    schema
  )

  override def beforeAll(): Unit = {
    super.beforeAll()
    jdbcUpdate(s"create or replace table $table(var variant)")
  }

  override def afterAll(): Unit = {
    jdbcUpdate(s"drop table $table")
    super.afterAll()
  }

  test("continue_on_error off") {

    assertThrows[SnowflakeSQLException] {
      df.write
        .format(SNOWFLAKE_SOURCE_NAME)
        .options(connectorOptionsNoTable)
        .option("dbtable", table)
        .mode(SaveMode.Append)
        .save()
    }
  }

  test("continue_on_error on") {
    df.write
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(connectorOptionsNoTable)
      .option("continue_on_error", "on")
      .option("dbtable", table)
      .mode(SaveMode.Append)
      .save()

    val result = sparkSession.read
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(connectorOptionsNoTable)
      .option("dbtable", table)
      .load()

    assert(result.collect().length == 1)
  }

} 
Example 4
Source File: SparkLensTest.scala    From spark-tools   with Apache License 2.0
package io.univalence

import org.apache.spark.SparkConf
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import io.univalence.SparkLens._
import org.scalatest.FunSuite

case class Toto(name: String, age: Int)

case class Tata(toto: Toto)

class SparkLensTest extends FunSuite {

  val conf: SparkConf = new SparkConf()
  conf.setAppName("yo")
  conf.setMaster("local[*]")

  implicit val ss: SparkSession = SparkSession.builder.config(conf).getOrCreate

  import ss.implicits._

  test("testLensRegExp change string") {
    assert(lensRegExp(ss.createDataFrame(Seq(Toto("a", 1))))({
      case ("name", StringType) => true
      case _                    => false
    }, { case (a: String, d)    => a.toUpperCase }).as[Toto].first() == Toto("A", 1))
  }

  test("change Int") {
    assert(lensRegExp(ss.createDataFrame(Seq(Tata(Toto("a", 1)))))({
      case ("toto/age", _) => true
      case _               => false
    }, { case (a: Int, d)  => a + 1 }).as[Tata].first() == Tata(Toto("a", 2)))
  }

  ignore("null to nil") {

    val df: DataFrame = ss.read.parquet("/home/phong/daily_gpp_20180705")

    val yoho: DataFrame = lensRegExp(df)({
      case (_, ArrayType(_, _)) => true
      case _                    => false
    }, (a, b) => if (a == null) Nil else a)

  }

} 
Example 5
Source File: ConfigurableDataGeneratorMain.scala    From Spark.TableStatsExample   with Apache License 2.0
package com.cloudera.sa.examples.tablestats

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.{StringType, LongType, StructField, StructType}
import org.apache.spark.{SparkContext, SparkConf}

import scala.collection.mutable
import scala.util.Random



object ConfigurableDataGeneratorMain {
  def main(args: Array[String]): Unit = {

    if (args.length == 0) {
      println("ConfigurableDataGeneratorMain <outputPath> <numberOfColumns> <numberOfRecords> <numberOfPartitions> <local>")
      return
    }

    val outputPath = args(0)
    val numberOfColumns = args(1).toInt
    val numberOfRecords = args(2).toInt
    val numberOfPartitions = args(3).toInt
    val runLocal = (args.length == 5 && args(4).equals("L"))

    var sc: SparkContext = null
    if (runLocal) {
      val sparkConfig = new SparkConf()
      sparkConfig.set("spark.broadcast.compress", "false")
      sparkConfig.set("spark.shuffle.compress", "false")
      sparkConfig.set("spark.shuffle.spill.compress", "false")
      sc = new SparkContext("local", "test", sparkConfig)
    } else {
      val sparkConfig = new SparkConf().setAppName("ConfigurableDataGeneratorMain")
      sc = new SparkContext(sparkConfig)
    }

    val sqlContext = new org.apache.spark.sql.SQLContext(sc)

    //Part A
    val rowRDD = sc.parallelize( (0 until numberOfPartitions).map( i => i), numberOfPartitions)

    //Part B
    val megaDataRDD = rowRDD.flatMap( r => {
      val random = new Random()

      val dataRange = (0 until numberOfRecords/numberOfPartitions).iterator
      dataRange.map[Row]( x => {
        val values = new mutable.ArrayBuffer[Any]
        for (i <- 0 until numberOfColumns) {
          if (i % 2 == 0) {
            values.+=(random.nextInt(100).toLong)
          } else {
            values.+=(random.nextInt(100).toString)
          }
        }
        new GenericRow(values.toArray)
      })
    })

    //Part C
    val schema =
      StructType(
        (0 until numberOfColumns).map( i => {
          if (i % 2 == 0) {
            StructField("longColumn_" + i, LongType, true) }
          else {
            StructField("stringColumn_" + i, StringType, true)
          }
        })
      )
    val df = sqlContext.createDataFrame(megaDataRDD, schema)
    df.saveAsParquetFile(outputPath)

    //Part D
    sc.stop()
  }
} 
Example 6
Source File: TestTableStatsSinglePathMain.scala    From Spark.TableStatsExample   with Apache License 2.0
package com.cloudera.sa.examples.tablestats


import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, LongType, StructField, StructType}
import org.scalatest.{FunSuite, BeforeAndAfterEach, BeforeAndAfterAll}


class TestTableStatsSinglePathMain extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll{
  test("run table stats on sample data") {

    val sparkConfig = new SparkConf()
    sparkConfig.set("spark.broadcast.compress", "false")
    sparkConfig.set("spark.shuffle.compress", "false")
    sparkConfig.set("spark.shuffle.spill.compress", "false")
    var sc = new SparkContext("local", "test", sparkConfig)
    try {
      val sqlContext = new org.apache.spark.sql.SQLContext(sc)

      val schema =
        StructType(
          Array(
            StructField("id", LongType, true),
            StructField("name", StringType, true),
            StructField("age", LongType, true),
            StructField("gender", StringType, true),
            StructField("height", LongType, true),
            StructField("job_title", StringType, true)
          )
        )

      val rowRDD = sc.parallelize(Array(
        Row(1l, "Name.1", 20l, "M", 6l, "dad"),
        Row(2l, "Name.2", 20l, "F", 5l, "mom"),
        Row(3l, "Name.3", 20l, "F", 5l, "mom"),
        Row(4l, "Name.4", 20l, "M", 5l, "mom"),
        Row(5l, "Name.5", 10l, "M", 4l, "kid"),
        Row(6l, "Name.6", 8l, "M", 3l, "kid")))

      val df = sqlContext.createDataFrame(rowRDD, schema)

      val firstPassStats = TableStatsSinglePathMain.getFirstPassStat(df)

      assertResult(6l)(firstPassStats.columnStatsMap(0).maxLong)
      assertResult(1l)(firstPassStats.columnStatsMap(0).minLong)
      assertResult(21l)(firstPassStats.columnStatsMap(0).sumLong)
      assertResult(3l)(firstPassStats.columnStatsMap(0).avgLong)

      assertResult(2)(firstPassStats.columnStatsMap(3).topNValues.topNCountsForColumnArray.length)

      firstPassStats.columnStatsMap(3).topNValues.topNCountsForColumnArray.foreach { r =>
        if (r._1.equals("M")) {
          assertResult(4l)(r._2)
        } else if (r._1.equals("F")) {
          assertResult(2l)(r._2)
        } else {
          throw new RuntimeException("Unknown gender: " + r._1)
        }
      }
    } finally {
      sc.stop()
    }
  }
} 
Example 7
Source File: Tokenizer.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}


  @Since("1.6.0")
  def getToLowercase: Boolean = $(toLowercase)

  setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)

  override protected def createTransformFunc: String => Seq[String] = { originStr =>
    val re = $(pattern).r
    val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
    val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
    val minLength = $(minTokenLength)
    tokens.filter(_.length >= minLength)
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be string type but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, true)

  @Since("1.4.1")
  override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
}

@Since("1.6.0")
object RegexTokenizer extends DefaultParamsReadable[RegexTokenizer] {

  @Since("1.6.0")
  override def load(path: String): RegexTokenizer = super.load(path)
} 
Example 8
Source File: NGram.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}


  @Since("1.5.0")
  def getN: Int = $(n)

  setDefault(n -> 2)

  override protected def createTransformFunc: Seq[String] => Seq[String] = {
    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType.sameType(ArrayType(StringType)),
      s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, false)
}

@Since("1.6.0")
object NGram extends DefaultParamsReadable[NGram] {

  @Since("1.6.0")
  override def load(path: String): NGram = super.load(path)
} 
Example 9
Source File: InputFileName.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.rdd.InputFileNameHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String


@ExpressionDescription(
  usage = "_FUNC_() - Returns the name of the current file being read if available",
  extended = "> SELECT _FUNC_();\n ''")
case class InputFileName() extends LeafExpression with Nondeterministic {

  override def nullable: Boolean = true

  override def dataType: DataType = StringType

  override def prettyName: String = "input_file_name"

  override protected def initInternal(): Unit = {}

  override protected def evalInternal(input: InternalRow): UTF8String = {
    InputFileNameHolder.getInputFileName()
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
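    // Emit Java source that reads the current input file name directly from
    // InputFileNameHolder; the generated value is a UTF8String and is treated as never null.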
    ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
      "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();", isNull = "false")
  }
} 
Example 10
Source File: MapDataSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }
} 
Example 11
Source File: ScalaUDFSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase,
      StringType,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

} 
Example 12
Source File: CallMethodViaReflectionSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.types.{IntegerType, StringType}


class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelper {

  import CallMethodViaReflection._

  // Get rid of the $ so we are getting the companion object's name.
  private val staticClassName = ReflectStaticClass.getClass.getName.stripSuffix("$")
  private val dynamicClassName = classOf[ReflectDynamicClass].getName

  test("findMethod via reflection for static methods") {
    assert(findMethod(staticClassName, "method1", Seq.empty).exists(_.getName == "method1"))
    assert(findMethod(staticClassName, "method2", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method3", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method4", Seq(IntegerType, StringType)).isDefined)
  }

  test("findMethod for a JDK library") {
    assert(findMethod(classOf[java.util.UUID].getName, "randomUUID", Seq.empty).isDefined)
  }

  test("class not found") {
    val ret = createExpr("some-random-class", "method").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("not found") && errorMsg.contains("class"))
  }

  test("method not found because name does not match") {
    val ret = createExpr(staticClassName, "notfoundmethod").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))
  }

  test("method not found because there is no static method") {
    val ret = createExpr(dynamicClassName, "method1").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))
  }

  test("input type checking") {
    assert(CallMethodViaReflection(Seq.empty).checkInputDataTypes().isFailure)
    assert(CallMethodViaReflection(Seq(Literal(staticClassName))).checkInputDataTypes().isFailure)
    assert(CallMethodViaReflection(
      Seq(Literal(staticClassName), Literal(1))).checkInputDataTypes().isFailure)
    assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess)
  }

  test("invoking methods using acceptable types") {
    checkEvaluation(createExpr(staticClassName, "method1"), "m1")
    checkEvaluation(createExpr(staticClassName, "method2", 2), "m2")
    checkEvaluation(createExpr(staticClassName, "method3", 3), "m3")
    checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four")
  }

  private def createExpr(className: String, methodName: String, args: Any*) = {
    CallMethodViaReflection(
      Literal.create(className, StringType) +:
      Literal.create(methodName, StringType) +:
      args.map(Literal.apply)
    )
  }
} 
Example 13
Source File: RewriteDistinctAggregatesSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        CollectSet('b).toAggregateExpression().as('agg2))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 14
Source File: resources.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
Example 15
Source File: WholeStageCodegenSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined)
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }
} 
Example 16
Source File: GroupedIteratorSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 17
Source File: DDLSourceLoadSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.sources

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// please note that the META-INF/services had to be modified for the test directory for this to work
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {
    intercept[RuntimeException] {
      spark.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {
    spark.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("specify full classname with duplicate formats") {
    spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("should fail to load ORC without Hive Support") {
    val e = intercept[AnalysisException] {
      spark.read.format("orc").load()
    }
    assert(e.message.contains("The ORC data source must be used with Hive support enabled"))
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider  with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
Example 18
Source File: CarbonDataFrameExample.scala    From CarbonDataLearning   with GNU General Public License v3.0
package org.github.xubo245.carbonDataLearning.example

import org.apache.carbondata.examples.util.ExampleUtils
import org.apache.spark.sql.{SaveMode, SparkSession}

object CarbonDataFrameExample {

  def main(args: Array[String]) {
    val spark = ExampleUtils.createCarbonSession("CarbonDataFrameExample")
    exampleBody(spark)
    spark.close()
  }

  def exampleBody(spark : SparkSession): Unit = {
    // Writes Dataframe to CarbonData file:
    import spark.implicits._
    val df = spark.sparkContext.parallelize(1 to 100)
      .map(x => ("a" + x % 10, "b", x))
      .toDF("c1", "c2", "number")

    // Saves dataframe to carbondata file
    df.write
      .format("carbondata")
      .option("tableName", "carbon_df_table")
      .option("partitionColumns", "c1")  // a list of column names
      .mode(SaveMode.Overwrite)
      .save()

    spark.sql(""" SELECT * FROM carbon_df_table """).show()

    spark.sql("SHOW PARTITIONS carbon_df_table").show()

    // Specify schema
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    val customSchema = StructType(Array(
      StructField("c1", StringType),
      StructField("c2", StringType),
      StructField("number", IntegerType)))

    // Reads carbondata to dataframe
    val carbondf = spark.read
      .format("carbondata")
      .schema(customSchema)
      // .option("dbname", "db_name") the system will use "default" as dbname if not set this option
      .option("tableName", "carbon_df_table")
      .load()

    df.write
      .format("csv")
      .option("tableName", "csv_df_table")
      .option("partitionColumns", "c1") // a list of column names
      //      .option("timestampFormat", "yyyy/MM/dd HH:mm:ss ZZ")
      .mode(SaveMode.Overwrite)
      .csv("/Users/xubo/Desktop/xubo/git/carbondata3/examples/spark2/target/csv/1.csv")

    // Reads the csv files written above back into a dataframe
    val carbondf2 = spark.read
      .format("csv")
      .schema(customSchema)
      // .option("dbname", "db_name") the system will use "default" as dbname if not set this option
      .option("tableName", "csv_df_table")

      //      .option("timestampFormat", "yyyy/MM/dd HH:mm:ss ZZ")
      .load("/Users/xubo/Desktop/xubo/git/carbondata3/examples/spark2/target/csv")

    carbondf2.show()


    // Dataframe operations
    carbondf.printSchema()
    carbondf.select($"c1", $"number" + 10).show()
    carbondf.filter($"number" > 31).show()

    spark.sql("DROP TABLE IF EXISTS carbon_df_table")
  }
} 
Example 19
Source File: TestSFObjectWriter.scala    From spark-salesforce   with Apache License 2.0
package com.springml.spark.salesforce

import org.mockito.Mockito._
import org.mockito.Matchers._
import org.scalatest.mock.MockitoSugar
import org.scalatest.{ FunSuite, BeforeAndAfterEach}
import com.springml.salesforce.wave.api.BulkAPI
import org.apache.spark.{ SparkConf, SparkContext}
import com.springml.salesforce.wave.model.{ JobInfo, BatchInfo}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{ Row, DataFrame, SQLContext}
import org.apache.spark.sql.types.{ StructType, StringType, StructField}

class TestSFObjectWriter extends FunSuite with MockitoSugar with BeforeAndAfterEach {
  val contact = "Contact";
  val jobId = "750B0000000WlhtIAC";
  val batchId = "751B0000000scSHIAY";
  val data = "Id,Description\n003B00000067Rnx,123456\n003B00000067Rnw,7890";

  val bulkAPI = mock[BulkAPI](withSettings().serializable())
  val writer = mock[SFObjectWriter]

  var sparkConf: SparkConf = _
  var sc: SparkContext = _

  override def beforeEach() {
    val jobInfo = new JobInfo
    jobInfo.setId(jobId)
    when(bulkAPI.createJob(contact)).thenReturn(jobInfo)

    val batchInfo = new BatchInfo
    batchInfo.setId(batchId)
    batchInfo.setJobId(jobId)
    when(bulkAPI.addBatch(jobId, data)).thenReturn(batchInfo)

    when(bulkAPI.closeJob(jobId)).thenReturn(jobInfo)
    when(bulkAPI.isCompleted(jobId)).thenReturn(true)

    sparkConf = new SparkConf().setMaster("local").setAppName("Test SF Object Update")
    sc = new SparkContext(sparkConf)
  }

  private def sampleDF() : DataFrame = {
    val rowArray = new Array[Row](2)
    val fieldArray = new Array[String](2)

    fieldArray(0) = "003B00000067Rnx"
    fieldArray(1) = "Desc1"
    rowArray(0) = Row.fromSeq(fieldArray)

    val fieldArray1 = new Array[String](2)
    fieldArray1(0) = "001B00000067Rnx"
    fieldArray1(1) = "Desc2"
    rowArray(1) = Row.fromSeq(fieldArray1)

    val rdd = sc.parallelize(rowArray)
    val schema = StructType(
      StructField("id", StringType, true) ::
      StructField("desc", StringType, true) :: Nil)

    val sqlContext = new SQLContext(sc)
    sqlContext.createDataFrame(rdd, schema)
  }

  test ("Write Object to Salesforce") {
    val df = sampleDF();
    val csvHeader = Utils.csvHeadder(df.schema)
    writer.writeData(df.rdd)
    sc.stop()
  }
} 
Example 20
Source File: TestMetadataConstructor.scala    From spark-salesforce   with Apache License 2.0
package com.springml.spark.salesforce.metadata

import org.apache.spark.sql.types.{StructType, StringType, IntegerType, LongType,
  FloatType, DateType, TimestampType, BooleanType, StructField}
import org.scalatest.FunSuite
import com.springml.spark.salesforce.Utils


class TestMetadataConstructor extends FunSuite {

  test("Test Metadata generation") {
    val columnNames = List("c1", "c2", "c3", "c4")
    val columnStruct = columnNames.map(colName => StructField(colName, StringType, true))
    val schema = StructType(columnStruct)

    val schemaString = MetadataConstructor.generateMetaString(schema,"sampleDataSet", Utils.metadataConfig(null))
    assert(schemaString.length > 0)
    assert(schemaString.contains("sampleDataSet"))
  }

  test("Test Metadata generation With Custom MetadataConfig") {
    val columnNames = List("c1", "c2", "c3", "c4")
    val intField = StructField("intCol", IntegerType, true)
    val longField = StructField("longCol", LongType, true)
    val floatField = StructField("floatCol", FloatType, true)
    val dateField = StructField("dateCol", DateType, true)
    val timestampField = StructField("timestampCol", TimestampType, true)
    val stringField = StructField("stringCol", StringType, true)
    val someTypeField = StructField("someTypeCol", BooleanType, true)

    val columnStruct = Array[StructField] (intField, longField, floatField, dateField, timestampField, stringField, someTypeField)

    val schema = StructType(columnStruct)

    var metadataConfig = Map("string" -> Map("wave_type" -> "Text"))
    metadataConfig += ("integer" -> Map("wave_type" -> "Numeric", "precision" -> "10", "scale" -> "0", "defaultValue" -> "100"))
    metadataConfig += ("float" -> Map("wave_type" -> "Numeric", "precision" -> "10", "scale" -> "2"))
    metadataConfig += ("long" -> Map("wave_type" -> "Numeric", "precision" -> "18", "scale" -> "0"))
    metadataConfig += ("date" -> Map("wave_type" -> "Date", "format" -> "yyyy/MM/dd"))
    metadataConfig += ("timestamp" -> Map("wave_type" -> "Date", "format" -> "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"))


    val schemaString = MetadataConstructor.generateMetaString(schema, "sampleDataSet", metadataConfig)
    assert(schemaString.length > 0)
    assert(schemaString.contains("sampleDataSet"))
    assert(schemaString.contains("Numeric"))
    assert(schemaString.contains("precision"))
    assert(schemaString.contains("scale"))
    assert(schemaString.contains("18"))
    assert(schemaString.contains("Text"))
    assert(schemaString.contains("Date"))
    assert(schemaString.contains("format"))
    assert(schemaString.contains("defaultValue"))
    assert(schemaString.contains("100"))
    assert(schemaString.contains("yyyy/MM/dd"))
    assert(schemaString.contains("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"))
  }
} 
Example 21
Source File: SparkScoreDoc.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd.models

import org.apache.lucene.document.Document
import org.apache.lucene.index.IndexableField
import org.apache.lucene.search.{IndexSearcher, ScoreDoc}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.zouzias.spark.lucenerdd.models.SparkScoreDoc.inferNumericType
import org.zouzias.spark.lucenerdd.models.SparkScoreDoc.{DocIdField, ScoreField, ShardField}

import scala.collection.JavaConverters._

sealed trait FieldType extends Serializable
object TextType extends FieldType
object IntType extends FieldType
object DoubleType extends FieldType
object LongType extends FieldType
object FloatType extends FieldType



  // Maps a boxed java.lang.Number to the matching FieldType, falling back to TextType.
  private def inferNumericType(num: Number): FieldType = {
    num match {
      case _: java.lang.Double => DoubleType
      case _: java.lang.Long => LongType
      case _: java.lang.Integer => IntType
      case _: java.lang.Float => FloatType
      case _ => TextType
    }
  }
} 
Example 22
Source File: Schema.scala    From incubator-s2graph   with Apache License 2.0
package org.apache.s2graph.s2jobs

import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

object Schema {
  // CommonFields (base columns shared by every graph element) is defined in the full
  // source file and not shown in this excerpt.
  val GraphElementSchema = StructType(CommonFields ++ Seq(
    StructField("id", StringType, nullable = true),
    StructField("service", StringType, nullable = true),
    StructField("column", StringType, nullable = true),
    StructField("from", StringType, nullable = true),
    StructField("to", StringType, nullable = true),
    StructField("label", StringType, nullable = true),
    StructField("props", StringType, nullable = true)
  ))
} 
Example 23
Source File: ExecutePython.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.script

import java.util
import java.util.UUID

import cn.piflow.conf.bean.PropertyDescriptor
import cn.piflow.conf.util.{ImageUtil, MapUtil}
import cn.piflow.conf.{ConfigurableStop, Port, StopGroup}
import cn.piflow.util.FileUtil
import cn.piflow.{JobContext, JobInputStream, JobOutputStream, ProcessContext}
import jep.Jep
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.JavaConversions._


class ExecutePython extends ConfigurableStop{
  override val authorEmail: String = "[email protected]"
  override val description: String = "Execute python script"
  override val inportList: List[String] = List(Port.DefaultPort)
  override val outportList: List[String] = List(Port.DefaultPort)

  var script : String = _

  override def setProperties(map: Map[String, Any]): Unit = {
    script = MapUtil.get(map,"script").asInstanceOf[String]
  }

  override def getPropertyDescriptor(): List[PropertyDescriptor] = {
    var descriptor : List[PropertyDescriptor] = List()
    val script = new PropertyDescriptor()
      .name("script")
      .displayName("script")
      .description("The code of python")
      .defaultValue("")
      .required(true)

    descriptor = script :: descriptor
    descriptor
  }

  override def getIcon(): Array[Byte] = {
    ImageUtil.getImage("icon/script/python.png")
  }

  override def getGroup(): List[String] = {
    List(StopGroup.ScriptGroup)
  }
  override def initialize(ctx: ProcessContext): Unit = {}

  override def perform(in: JobInputStream, out: JobOutputStream, pec: JobContext): Unit = {
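    // Write the user-supplied script to a temporary file and run it with an embedded
    // Jep (Java Embedded Python) interpreter.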

    val jep = new Jep()
    val scriptPath = "/tmp/pythonExcutor-"+ UUID.randomUUID() +".py"
    FileUtil.writeFile(script,scriptPath)
    jep.runScript(scriptPath)
  }
} 
Example 24
Source File: ExecutePythonWithDataFrame.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.script

import java.util
import java.util.UUID

import cn.piflow.conf.bean.PropertyDescriptor
import cn.piflow.conf.util.{ImageUtil, MapUtil}
import cn.piflow.conf.{ConfigurableStop, Port, StopGroup}
import cn.piflow.util.FileUtil
import cn.piflow.{JobContext, JobInputStream, JobOutputStream, ProcessContext}
import jep.Jep
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.JavaConversions._


class ExecutePythonWithDataFrame extends ConfigurableStop{
  override val authorEmail: String = "[email protected]"
  override val description: String = "Execute python script with dataframe"
  override val inportList: List[String] = List(Port.DefaultPort)
  override val outportList: List[String] = List(Port.DefaultPort)

  var script : String = _
  var execFunction : String = _

  override def setProperties(map: Map[String, Any]): Unit = {
    script = MapUtil.get(map,"script").asInstanceOf[String]
    execFunction = MapUtil.get(map,"execFunction").asInstanceOf[String]
  }

  override def getPropertyDescriptor(): List[PropertyDescriptor] = {
    var descriptor : List[PropertyDescriptor] = List()
    val script = new PropertyDescriptor()
      .name("script")
      .displayName("script")
      .description("The code of python")
      .defaultValue("")
      .required(true)
    val execFunction = new PropertyDescriptor()
      .name("execFunction")
      .displayName("execFunction")
      .description("The function of python script to be executed.")
      .defaultValue("")
      .required(true)
    descriptor = script :: descriptor
    descriptor = execFunction :: descriptor
    descriptor
  }

  override def getIcon(): Array[Byte] = {
    ImageUtil.getImage("icon/script/python.png")
  }

  override def getGroup(): List[String] = {
    List(StopGroup.ScriptGroup)
  }
  override def initialize(ctx: ProcessContext): Unit = {}

  override def perform(in: JobInputStream, out: JobOutputStream, pec: JobContext): Unit = {

    val spark = pec.get[SparkSession]()

    val df = in.read()

    val jep = new Jep()
    val scriptPath = "/tmp/pythonExcutor-"+ UUID.randomUUID() +".py"
    FileUtil.writeFile(script,scriptPath)
    jep.runScript(scriptPath)


    val listInfo = df.toJSON.collectAsList()
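    // Pass the collected JSON rows to the user-defined Python function and read the
    // result back as a list of maps, which is rebuilt into a DataFrame below.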
    jep.eval(s"result = $execFunction($listInfo)")
    val resultArrayList = jep.getValue("result",new util.ArrayList().getClass)
    println(resultArrayList)


    var resultList = List[Map[String, Any]]()
    val it = resultArrayList.iterator()
    while(it.hasNext){
      val i = it.next().asInstanceOf[java.util.HashMap[String, Any]]
      val item =  mapAsScalaMap(i).toMap[String, Any]
      resultList =  item +: resultList
    }


    val rows = resultList.map( m => Row(m.values.toSeq:_*))
    val header = resultList.head.keys.toList
    val schema = StructType(header.map(fieldName => new StructField(fieldName, StringType, true)))

    val rdd = spark.sparkContext.parallelize(rows)
    val resultDF = spark.createDataFrame(rdd, schema)

    out.write(resultDF)
  }
} 
Example 25
Source File: DataFrameRowParser.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.script

import cn.piflow.conf.bean.PropertyDescriptor
import cn.piflow.conf.util.{ImageUtil, MapUtil}
import cn.piflow.conf._
import cn.piflow.{JobContext, JobInputStream, JobOutputStream, ProcessContext}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

import scala.beans.BeanProperty

class DataFrameRowParser extends ConfigurableStop{

  val authorEmail: String = "[email protected]"
  val description: String = "Create dataframe by schema"
  val inportList: List[String] = List(Port.DefaultPort.toString)
  val outportList: List[String] = List(Port.DefaultPort.toString)

  var schema: String = _
  var separator: String = _

  override def setProperties(map: Map[String, Any]): Unit = {
    schema = MapUtil.get(map,"schema").asInstanceOf[String]
    separator = MapUtil.get(map,"separator").asInstanceOf[String]
  }

  override def getPropertyDescriptor(): List[PropertyDescriptor] = {
    var descriptor : List[PropertyDescriptor] = List()
    val schema = new PropertyDescriptor().name("schema").displayName("schema").description("The schema of dataframe").defaultValue("").required(true)
    val separator = new PropertyDescriptor().name("separator").displayName("separator").description("The separator of schema").defaultValue("").required(true)
    descriptor = schema :: descriptor
    descriptor = separator :: descriptor
    descriptor
  }

  override def getIcon(): Array[Byte] = {
    ImageUtil.getImage("icon/script/DataFrameRowParser.png")
  }

  override def getGroup(): List[String] = {
    List(StopGroup.ScriptGroup.toString)
  }


  override def initialize(ctx: ProcessContext): Unit = {}

  override def perform(in: JobInputStream, out: JobOutputStream, pec: JobContext): Unit = {
    val spark = pec.get[SparkSession]()
    val inDF = in.read()

    //parse RDD
    val rdd = inDF.rdd.map(row => {
      val fieldArray = row.get(0).asInstanceOf[String].split(",")
      Row.fromSeq(fieldArray.toSeq)
    })

    //parse schema
    val field = schema.split(separator)
    val structFieldArray : Array[StructField] = new Array[StructField](field.size)
    for(i <- 0 to field.size - 1){
      structFieldArray(i) = new StructField(field(i),StringType, nullable = true)
    }
    val schemaStructType = StructType(structFieldArray)

    //create DataFrame
    val df = spark.createDataFrame(rdd,schemaStructType)
    //df.show()
    out.write(df)
  }

} 
Example 26
Source File: WordSpliter.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.nlp

import cn.piflow._
import cn.piflow.conf._
import cn.piflow.conf.bean.PropertyDescriptor
import cn.piflow.conf.util.{ImageUtil, MapUtil}
import com.huaban.analysis.jieba.JiebaSegmenter.SegMode
import com.huaban.analysis.jieba._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

class WordSpliter extends ConfigurableStop {

  val authorEmail: String = "[email protected]"
  val description: String = "Word segmentation"
  val inportList: List[String] = List(Port.AnyPort.toString)
  val outportList: List[String] = List(Port.DefaultPort.toString)

  var path:String = _


  val jiebaSegmenter = new JiebaSegmenter()
  var tokenARR:ArrayBuffer[String]=ArrayBuffer()

  def segmenter(str:String): Unit ={

    var strVar = str
    //delete symbol
    strVar = strVar.replaceAll( "[\\p{P}+~$`^=|<>~`$^+=|<>¥×+\\s]" , "");

    val tokens = jiebaSegmenter.process(strVar,SegMode.SEARCH).asScala

    for (token: SegToken <- tokens){

        tokenARR += token.word

    }
  }

  def perform(in: JobInputStream, out: JobOutputStream, pec: JobContext): Unit = {

    val session: SparkSession = pec.get[SparkSession]()

    //read
    val strDF = session.read.text(path)

    //segmenter
    segmenter(strDF.head().getString(0))

    //write df
    val rows: List[Row] = tokenARR.map(each => {
      var arr:Array[String]=Array(each)
      val row: Row = Row.fromSeq(arr)
      row
    }).toList
    val rowRDD: RDD[Row] = session.sparkContext.makeRDD(rows)
    val schema: StructType = StructType(Array(
      StructField("words",StringType)
    ))
    val df: DataFrame = session.createDataFrame(rowRDD,schema)

    out.write(df)
  }

  def initialize(ctx: ProcessContext): Unit = {

  }

  def setProperties(map : Map[String, Any]) = {
    path = MapUtil.get(map,"path").asInstanceOf[String]
  }

  override def getPropertyDescriptor(): List[PropertyDescriptor] = {
    var descriptor : List[PropertyDescriptor] = List()
    val path = new PropertyDescriptor().name("path").displayName("path").description("The path of text file").defaultValue("").required(true)
    descriptor = path :: descriptor
    descriptor
  }

  override def getIcon(): Array[Byte] = {
    ImageUtil.getImage("icon/nlp/NLP.png")
  }

  override def getGroup(): List[String] = {
    List(StopGroup.Alg_NLPGroup.toString)
  }

} 
Example 27
Source File: BasicDataSourceSuite.scala    From tispark   with Apache License 2.0
package com.pingcap.tispark.datasource

import com.pingcap.tikv.exception.TiBatchWriteException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class BasicDataSourceSuite extends BaseDataSourceTest("test_datasource_basic") {
  private val row1 = Row(null, "Hello")
  private val row2 = Row(2, "TiDB")
  private val row3 = Row(3, "Spark")
  private val row4 = Row(4, null)

  private val schema = StructType(
    List(StructField("i", IntegerType), StructField("s", StringType)))

  override def beforeAll(): Unit = {
    super.beforeAll()

    dropTable()
    jdbcUpdate(s"create table $dbtable(i int, s varchar(128))")
    jdbcUpdate(s"insert into $dbtable values(null, 'Hello'), (2, 'TiDB')")
  }

  test("Test Select") {
    if (!supportBatchWrite) {
      cancel
    }

    testTiDBSelect(Seq(row1, row2))
  }

  test("Test Write Append") {
    if (!supportBatchWrite) {
      cancel
    }

    val data: RDD[Row] = sc.makeRDD(List(row3, row4))
    val df = sqlContext.createDataFrame(data, schema)

    df.write
      .format("tidb")
      .options(tidbOptions)
      .option("database", database)
      .option("table", table)
      .mode("append")
      .save()

    testTiDBSelect(Seq(row1, row2, row3, row4))
  }

  test("Test Write Overwrite") {
    if (!supportBatchWrite) {
      cancel
    }

    val data: RDD[Row] = sc.makeRDD(List(row3, row4))
    val df = sqlContext.createDataFrame(data, schema)

    val caught = intercept[TiBatchWriteException] {
      df.write
        .format("tidb")
        .options(tidbOptions)
        .option("database", database)
        .option("table", table)
        .mode("overwrite")
        .save()
    }

    assert(
      caught.getMessage
        .equals("SaveMode: Overwrite is not supported. TiSpark only support SaveMode.Append."))
  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
} 
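For use outside the test harness, the append path exercised above can be sketched as a standalone job. The DataFrame writer calls mirror the test; the tidb.addr/tidb.port/tidb.user/tidb.password option keys and the connection values are assumptions based on TiSpark's documented configuration and should be verified against the TiSpark version in use:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object TiDBAppendSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("tidb-append-sketch").getOrCreate()

    val schema = StructType(List(StructField("i", IntegerType), StructField("s", StringType)))
    val rows = List(Row(3, "Spark"), Row(4, null))
    val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)

    // Connection options; the key names are assumptions taken from TiSpark's documentation.
    val tidbOptions = Map(
      "tidb.addr" -> "127.0.0.1",
      "tidb.port" -> "4000",
      "tidb.user" -> "root",
      "tidb.password" -> "")

    df.write
      .format("tidb")
      .options(tidbOptions)
      .option("database", "test")
      .option("table", "test_datasource_basic")
      .mode("append") // only SaveMode.Append is supported, as the overwrite test above asserts
      .save()

    spark.stop()
  }
}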
Example 28
Source File: CheckUnsupportedSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import com.pingcap.tikv.exception.TiBatchWriteException
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class CheckUnsupportedSuite extends BaseDataSourceTest("test_datasource_check_unsupported") {

  override def beforeAll(): Unit =
    super.beforeAll()

  test("Test write to partition table") {
    if (!supportBatchWrite) {
      cancel
    }

    dropTable()

    tidbStmt.execute("set @@tidb_enable_table_partition = 1")

    jdbcUpdate(
      s"create table $dbtable(i int, s varchar(128)) partition by range(i) (partition p0 values less than maxvalue)")
    jdbcUpdate(s"insert into $dbtable values(null, 'Hello')")

    val row1 = Row(null, "Hello")
    val row2 = Row(2, "TiDB")
    val row3 = Row(3, "Spark")

    val schema = StructType(List(StructField("i", IntegerType), StructField("s", StringType)))

    {
      val caught = intercept[TiBatchWriteException] {
        tidbWrite(List(row2, row3), schema)
      }
      assert(
        caught.getMessage
          .equals("tispark currently does not support write data to partition table!"))
    }

    testTiDBSelect(Seq(row1))
  }

  test("Check Virtual Generated Column") {
    if (!supportBatchWrite) {
      cancel
    }

    dropTable()
    jdbcUpdate(s"create table $dbtable(i INT, c1 INT, c2 INT,  c3 INT AS (c1 + c2))")

    val row1 = Row(1, 2, 3)
    val schema = StructType(
      List(
        StructField("i", IntegerType),
        StructField("c1", IntegerType),
        StructField("c2", IntegerType)))

    val caught = intercept[TiBatchWriteException] {
      tidbWrite(List(row1), schema)
    }
    assert(
      caught.getMessage
        .equals("tispark currently does not support write data to table with generated column!"))

  }

  test("Check Stored Generated Column") {
    if (!supportBatchWrite) {
      cancel
    }

    dropTable()
    jdbcUpdate(s"create table $dbtable(i INT, c1 INT, c2 INT,  c3 INT AS (c1 + c2) STORED)")

    val row1 = Row(1, 2, 3)
    val schema = StructType(
      List(
        StructField("i", IntegerType),
        StructField("c1", IntegerType),
        StructField("c2", IntegerType)))
    val caught = intercept[TiBatchWriteException] {
      tidbWrite(List(row1), schema)
    }
    assert(
      caught.getMessage
        .equals("tispark currently does not support write data to table with generated column!"))

  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
} 
Example 29
Source File: TiSparkTypeSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

class TiSparkTypeSuite extends BaseDataSourceTest("type_test") {
  private val row1 = Row(null, "Hello")
  private val row2 = Row(2L, "TiDB")
  private val row3 = Row(3L, "Spark")
  private val row5 = Row(Long.MaxValue, "Duplicate")

  private val schema = StructType(List(StructField("i", LongType), StructField("s", StringType)))
  test("bigint test") {
    if (!supportBatchWrite) {
      cancel
    }

    dropTable()
    jdbcUpdate(s"create table $dbtable(i bigint, s varchar(128))")
    jdbcUpdate(s"insert into $dbtable values(null, 'Hello'), (2, 'TiDB')")

    tidbWrite(List(row3, row5), schema)
    testTiDBSelect(List(row1, row2, row3, row5))
  }
} 
Example 30
Source File: MissingParameterSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class MissingParameterSuite extends BaseDataSourceTest("test_datasource_missing_parameter") {
  private val row1 = Row(null, "Hello")

  private val schema = StructType(
    List(StructField("i", IntegerType), StructField("s", StringType)))

  test("Missing parameter: database") {
    if (!supportBatchWrite) {
      cancel
    }

    dropTable()
    jdbcUpdate(s"create table $dbtable(i int, s varchar(128))")

    val caught = intercept[IllegalArgumentException] {
      val rows = row1 :: Nil
      val data: RDD[Row] = sc.makeRDD(rows)
      val df = sqlContext.createDataFrame(data, schema)
      df.write
        .format("tidb")
        .options(tidbOptions)
        .option("table", table)
        .mode("append")
        .save()
    }
    assert(
      caught.getMessage
        .equals("requirement failed: Option 'database' is required."))
  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
} 
Example 31
Source File: BatchWriteIssueSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark

import com.pingcap.tispark.datasource.BaseDataSourceTest
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class BatchWriteIssueSuite extends BaseDataSourceTest("test_batchwrite_issue") {
  override def beforeAll(): Unit = {
    super.beforeAll()
  }

  test("Combine unique index with null value test") {
    doTestNullValues(s"create table $dbtable(a int, b varchar(64), CONSTRAINT ab UNIQUE (a, b))")
  }

  test("Combine primary key with null value test") {
    doTestNullValues(s"create table $dbtable(a int, b varchar(64), PRIMARY KEY (a, b))")
  }

  test("PK is handler with null value test") {
    doTestNullValues(s"create table $dbtable(a int, b varchar(64), PRIMARY KEY (a))")
  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }

  private def doTestNullValues(createTableSQL: String): Unit = {
    if (!supportBatchWrite) {
      cancel
    }
    val schema = StructType(
      List(
        StructField("a", IntegerType),
        StructField("b", StringType),
        StructField("c", StringType)))

    val options = Some(Map("replace" -> "true"))

    dropTable()
    jdbcUpdate(createTableSQL)
    jdbcUpdate(s"alter table $dbtable add column to_delete int")
    jdbcUpdate(s"alter table $dbtable add column c varchar(64) default 'c33'")
    jdbcUpdate(s"alter table $dbtable drop column to_delete")
    jdbcUpdate(s"""
                  |insert into $dbtable values(11, 'c12', null);
                  |insert into $dbtable values(21, 'c22', null);
                  |insert into $dbtable (a, b) values(31, 'c32');
                  |insert into $dbtable values(41, 'c42', 'c43');
                  |
      """.stripMargin)

    assert(queryTiDBViaJDBC(s"select c from $dbtable where a=11").head.head == null)
    assert(queryTiDBViaJDBC(s"select c from $dbtable where a=21").head.head == null)
    assert(
      queryTiDBViaJDBC(s"select c from $dbtable where a=31").head.head.toString.equals("c33"))
    assert(
      queryTiDBViaJDBC(s"select c from $dbtable where a=41").head.head.toString.equals("c43"))

    {
      val row1 = Row(11, "c12", "c13")
      val row3 = Row(31, "c32", null)

      tidbWrite(List(row1, row3), schema, options)

      assert(
        queryTiDBViaJDBC(s"select c from $dbtable where a=11").head.head.toString.equals("c13"))
      assert(queryTiDBViaJDBC(s"select c from $dbtable where a=21").head.head == null)
      assert(queryTiDBViaJDBC(s"select c from $dbtable where a=31").head.head == null)
      assert(
        queryTiDBViaJDBC(s"select c from $dbtable where a=41").head.head.toString.equals("c43"))
    }

    {
      val row1 = Row(11, "c12", "c213")
      val row3 = Row(31, "c32", "tt")
      tidbWrite(List(row1, row3), schema, options)
      assert(
        queryTiDBViaJDBC(s"select c from $dbtable where a=11").head.head.toString.equals("c213"))
      assert(queryTiDBViaJDBC(s"select c from $dbtable where a=21").head.head == null)
      assert(
        queryTiDBViaJDBC(s"select c from $dbtable where a=31").head.head.toString.equals("tt"))
      assert(
        queryTiDBViaJDBC(s"select c from $dbtable where a=41").head.head.toString.equals("c43"))
    }
  }
} 
Example 32
Source File: LockTimeoutSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.ttl

import com.pingcap.tikv.TTLManager
import com.pingcap.tikv.exception.GrpcException
import com.pingcap.tispark.datasource.BaseDataSourceTest
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class LockTimeoutSuite extends BaseDataSourceTest("test_lock_timeout") {
  private val row1 = Row(1, "Hello")

  private val schema = StructType(
    List(StructField("i", IntegerType), StructField("s", StringType)))

  override def beforeAll(): Unit = {
    super.beforeAll()
    dropTable()
    jdbcUpdate(s"create table $dbtable(i int, s varchar(128))")
  }

  test("Test Lock TTL Timeout") {
    if (!supportTTLUpdate) {
      cancel
    }

    val seconds = 1000
    val sleep1 = TTLManager.MANAGED_LOCK_TTL + 10 * seconds
    val sleep2 = TTLManager.MANAGED_LOCK_TTL + 15 * seconds

    val data: RDD[Row] = sc.makeRDD(List(row1))
    val df = sqlContext.createDataFrame(data, schema)

    new Thread(new Runnable {
      override def run(): Unit = {
        Thread.sleep(sleep1)
        queryTiDBViaJDBC(s"select * from $dbtable")
      }
    }).start()

    val grpcException = intercept[GrpcException] {
      df.write
        .format("tidb")
        .options(tidbOptions)
        .option("database", database)
        .option("table", table)
        .option("sleepAfterPrewritePrimaryKey", sleep2)
        .mode("append")
        .save()
    }

    assert(grpcException.getMessage.equals("retry is exhausted."))
    assert(grpcException.getCause.getMessage.startsWith("Txn commit primary key failed"))
    assert(
      grpcException.getCause.getCause.getMessage.startsWith(
        "Key exception occurred and the reason is retryable: \"Txn(Mvcc(TxnLockNotFound"))
  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
} 
Example 33
Source File: RedisStreamProvider.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package org.apache.spark.sql.redis.stream

import com.redislabs.provider.redis.util.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}


class RedisStreamProvider extends DataSourceRegister with StreamSourceProvider with Logging {

  override def shortName(): String = "redis"

  override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType],
                            providerName: String, parameters: Map[String, String]): (String, StructType) = {
    providerName -> schema.getOrElse {
      StructType(Seq(StructField("_id", StringType)))
    }
  }

  override def createSource(sqlContext: SQLContext, metadataPath: String,
                            schema: Option[StructType], providerName: String,
                            parameters: Map[String, String]): Source = {
    val (_, ss) = sourceSchema(sqlContext, schema, providerName, parameters)
    val source = new RedisSource(sqlContext, metadataPath, Some(ss), parameters)
    source.start()
    source
  }
} 
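As a rough consumer-side sketch (not taken from the project itself), the source registered above under the short name "redis" could be read with Structured Streaming as below; the spark.redis.host/port settings and the stream.keys option name are assumptions drawn from the spark-redis documentation, and the default schema is the single StringType "_id" column returned by sourceSchema:

import org.apache.spark.sql.SparkSession

object RedisStreamReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("redis-stream-sketch")
      .config("spark.redis.host", "localhost") // assumed spark-redis connection settings
      .config("spark.redis.port", "6379")
      .getOrCreate()

    // With no explicit schema, sourceSchema() falls back to a single StringType "_id" column.
    val stream = spark.readStream
      .format("redis")                    // short name registered by RedisStreamProvider
      .option("stream.keys", "my-stream") // assumed option name for the Redis stream key
      .load()

    val query = stream.writeStream
      .format("console")
      .start()

    query.awaitTermination()
  }
}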
Example 34
Source File: XSDToSchemaSuite.scala    From spark-xml   with Apache License 2.0 5 votes vote down vote up
package com.databricks.spark.xml.util

import java.nio.file.Paths

import org.apache.spark.sql.types.{ArrayType, StructField, StructType, StringType}
import org.scalatest.funsuite.AnyFunSuite

class XSDToSchemaSuite extends AnyFunSuite {

  test("Basic parsing") {
    val parsedSchema = XSDToSchema.read(Paths.get("src/test/resources/basket.xsd"))
    val expectedSchema = StructType(Array(
      StructField("basket", StructType(Array(
        StructField("entry", ArrayType(
          StructType(Array(
            StructField("key", StringType),
            StructField("value", StringType)
          )))
        ))
      )))
    )
    assert(expectedSchema === parsedSchema)
  }

} 
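A plausible follow-on, sketched here rather than taken from the suite, is to feed the parsed schema into a spark-xml read; the "xml" format and the rowTag option are spark-xml's standard entry points, while the file paths are placeholders:

import java.nio.file.Paths

import com.databricks.spark.xml.util.XSDToSchema
import org.apache.spark.sql.SparkSession

object XsdSchemaSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("xsd-schema-sketch").getOrCreate()

    // Derive a Spark schema (with StringType leaves) from the XSD, as the test above does.
    val schema = XSDToSchema.read(Paths.get("src/test/resources/basket.xsd"))
    println(schema.treeString)

    // The schema can then be applied to an XML read; the rowTag element and the exact
    // nesting to pass to .schema(...) depend on the XSD and may need adjusting.
    val df = spark.read
      .format("xml")
      .option("rowTag", "basket")
      .schema(schema)
      .load("path/to/baskets.xml") // placeholder path

    df.printSchema()
    spark.stop()
  }
}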
Example 35
Source File: SparkExecuteStatementOperationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType}

class SparkExecuteStatementOperationSuite extends SparkFunSuite {
  test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") {
    val field1 = StructField("NULL", NullType)
    val field2 = StructField("(IF(true, NULL, NULL))", NullType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
  }

  test("SPARK-20146 Comment should be preserved") {
    val field1 = StructField("column1", StringType).withComment("comment 1")
    val field2 = StructField("column2", IntegerType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE)
    assert(columns.get(0).getComment() == "comment 1")
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE)
    assert(columns.get(1).getComment() == "")
  }
} 
Example 36
Source File: datasources.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.xsql.execution.command

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.xsql.XSQLSessionCatalog

case class XSQLShowDatasourcesCommand(datasourcePattern: Option[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("dataSourceName", StringType, nullable = false)() :: Nil
  }

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    val datasources =
      datasourcePattern
        .map { pattern =>
          catalog.listDatasources(pattern)
        }
        .getOrElse(catalog.listDatasources())
    datasources.map { d =>
      Row(d)
    }
  }
}

case class XSQLAddDatasourceCommand(dataSourceName: String, properties: Map[String, String])
  extends RunnableCommand {
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    catalog.addDataSource(dataSourceName, properties)
    Seq.empty[Row]
  }
}

case class XSQLRemoveDatasourceCommand(dataSourceName: String, ifExists: Boolean)
  extends RunnableCommand {
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    catalog.removeDataSource(dataSourceName, ifExists)
    Seq.empty[Row]
  }
}

case class XSQLRefreshDatasourceCommand(dataSourceName: String) extends RunnableCommand {
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    catalog.refreshDataSource(dataSourceName)
    Seq.empty[Row]
  }
} 
Example 37
Source File: databases.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.xsql.execution.command

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.xsql.XSQLSessionCatalog


case class XSQLSetDatabaseCommand(dataSourceName: Option[String], databaseName: String)
  extends RunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sessionState.catalog.asInstanceOf[XSQLSessionCatalog]
    if (dataSourceName.isEmpty) {
      catalog.setCurrentDatabase(databaseName)
    } else {
      catalog.setCurrentDatabase(dataSourceName.get, databaseName)
    }
    Seq.empty[Row]
  }
} 
Example 38
Source File: inputFileBlock.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String


@ExpressionDescription(
  usage = "_FUNC_() - Returns the name of the file being read, or empty string if not available.")
case class InputFileName() extends LeafExpression with Nondeterministic {

  override def nullable: Boolean = false

  override def dataType: DataType = StringType

  override def prettyName: String = "input_file_name"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): UTF8String = {
    InputFileBlockHolder.getInputFilePath
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();",
      isNull = FalseLiteral)
  }
}


@ExpressionDescription(
  usage = "_FUNC_() - Returns the start offset of the block being read, or -1 if not available.")
case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override def prettyName: String = "input_file_block_start"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): Long = {
    InputFileBlockHolder.getStartOffset
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral)
  }
}


@ExpressionDescription(
  usage = "_FUNC_() - Returns the length of the block being read, or -1 if not available.")
case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override def prettyName: String = "input_file_block_length"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): Long = {
    InputFileBlockHolder.getLength
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
  }
} 
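These three expressions back the input_file_name, input_file_block_start and input_file_block_length SQL functions. A minimal usage sketch through the public API (the input path is a placeholder; the two block functions are assumed to be registered SQL functions, as they are in Spark 2.2+):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.input_file_name

object InputFileFunctionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("input-file-sketch").getOrCreate()

    val df = spark.read.text("path/to/some.txt") // placeholder path

    // input_file_name() yields a StringType column; the block offset/length
    // expressions are exposed as SQL functions of the same names.
    df.select(input_file_name().as("file")).show(false)

    df.selectExpr(
      "input_file_name() as file",
      "input_file_block_start() as block_start",
      "input_file_block_length() as block_length"
    ).show(false)

    spark.stop()
  }
}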
Example 39
Source File: StatsEstimationTestBase.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}


trait StatsEstimationTestBase extends SparkFunSuite {

  var originalValue: Boolean = false

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Enable stats estimation based on CBO.
    originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
  }

  override def afterAll(): Unit = {
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
    super.afterAll()
  }

  def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
    // For UTF8String: base + offset + numBytes
    case StringType => colStat.avgLen.getOrElse(attribute.dataType.defaultSize.toLong) + 8 + 4
    case _ => colStat.avgLen.getOrElse(attribute.dataType.defaultSize)
  }

  def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()
}

case class StatsTestPlan(
    outputList: Seq[Attribute],
    rowCount: BigInt,
    attributeStats: AttributeMap[ColumnStat],
    size: Option[BigInt] = None) extends LeafNode {
  override def output: Seq[Attribute] = outputList
  override def computeStats(): Statistics = Statistics(
    // If sizeInBytes is useless in testing, we just use a fake value
    sizeInBytes = size.getOrElse(Int.MaxValue),
    rowCount = Some(rowCount),
    attributeStats = attributeStats)
} 
Example 40
Source File: ScalaUDFSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase(Locale.ROOT),
      StringType,
      Literal.create(null, StringType) :: Nil,
      true :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvaluationWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

  test("SPARK-22695: ScalaUDF should not use global variables") {
    val ctx = new CodegenContext
    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil).genCode(ctx)
    assert(ctx.inlinedMutableStates.isEmpty)
  }
} 
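ScalaUDF is the internal representation; in user code the same string-in, string-out behaviour is normally expressed through the public UDF API. A small sketch, not part of the suite:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object StringUdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("string-udf-sketch").getOrCreate()
    import spark.implicits._

    // Equivalent, at the API level, to the (s: String) => s + "x" ScalaUDF in the test.
    val appendX = udf((s: String) => s + "x")

    val df = Seq("a", "b").toDF("value")
    df.select(appendX(col("value")).as("appended")).show()

    // The same function can also be registered for SQL use.
    spark.udf.register("append_x", (s: String) => s + "x")
    df.createOrReplaceTempView("t")
    spark.sql("SELECT append_x(value) FROM t").show()

    spark.stop()
  }
}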
Example 41
Source File: CallMethodViaReflectionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.types.{IntegerType, StringType}


class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelper {

  import CallMethodViaReflection._

  // Get rid of the $ so we are getting the companion object's name.
  private val staticClassName = ReflectStaticClass.getClass.getName.stripSuffix("$")
  private val dynamicClassName = classOf[ReflectDynamicClass].getName

  test("findMethod via reflection for static methods") {
    assert(findMethod(staticClassName, "method1", Seq.empty).exists(_.getName == "method1"))
    assert(findMethod(staticClassName, "method2", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method3", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method4", Seq(IntegerType, StringType)).isDefined)
  }

  test("findMethod for a JDK library") {
    assert(findMethod(classOf[java.util.UUID].getName, "randomUUID", Seq.empty).isDefined)
  }

  test("class not found") {
    val ret = createExpr("some-random-class", "method").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("not found") && errorMsg.contains("class"))
  }

  test("method not found because name does not match") {
    val ret = createExpr(staticClassName, "notfoundmethod").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))
  }

  test("method not found because there is no static method") {
    val ret = createExpr(dynamicClassName, "method1").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))
  }

  test("input type checking") {
    assert(CallMethodViaReflection(Seq.empty).checkInputDataTypes().isFailure)
    assert(CallMethodViaReflection(Seq(Literal(staticClassName))).checkInputDataTypes().isFailure)
    assert(CallMethodViaReflection(
      Seq(Literal(staticClassName), Literal(1))).checkInputDataTypes().isFailure)
    assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess)
  }

  test("unsupported type checking") {
    val ret = createExpr(staticClassName, "method1", new Timestamp(1)).checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("arguments from the third require boolean, byte, short"))
  }

  test("invoking methods using acceptable types") {
    checkEvaluation(createExpr(staticClassName, "method1"), "m1")
    checkEvaluation(createExpr(staticClassName, "method2", 2), "m2")
    checkEvaluation(createExpr(staticClassName, "method3", 3), "m3")
    checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four")
  }

  private def createExpr(className: String, methodName: String, args: Any*) = {
    CallMethodViaReflection(
      Literal.create(className, StringType) +:
      Literal.create(methodName, StringType) +:
      args.map(Literal.apply)
    )
  }
} 
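CallMethodViaReflection is surfaced in SQL as reflect (and java_method). A short sketch of the user-facing call, using only standard JDK classes; the class/method names and the StringType result mirror what the suite's createExpr helper builds:

import org.apache.spark.sql.SparkSession

object ReflectFunctionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("reflect-sketch").getOrCreate()

    // Class and method names are passed as string literals; the result is also a string.
    spark.sql("SELECT reflect('java.util.UUID', 'randomUUID') AS uuid").show(false)
    spark.sql("SELECT java_method('java.lang.String', 'valueOf', 1) AS one").show(false)

    spark.stop()
  }
}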
Example 42
Source File: LikeSimplificationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{BooleanType, StringType}

class LikeSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Like Simplification", Once,
        LikeSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify Like into StartsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc%") || ('a like "abc\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(StartsWith('a, "abc") || ('a like "abc\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EndsWith") {
    val originalQuery =
      testRelation
        .where('a like "%xyz")

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(EndsWith('a, "xyz"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into startsWith and EndsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc\\%def") || ('a like "abc%def"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a like "abc\\%def") ||
        (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def"))))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into Contains") {
    val originalQuery =
      testRelation
        .where(('a like "%mn%") || ('a like "%mn\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(Contains('a, "mn") || ('a like "%mn\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EqualTo") {
    val originalQuery =
      testRelation
        .where(('a like "") || ('a like "abc"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a === "") || ('a === "abc"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("null pattern") {
    val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze
    val optimized = Optimize.execute(originalQuery)
    comparePlans(optimized, testRelation.where(Literal(null, BooleanType)).analyze)
  }
} 
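From the query side these rewrites fire on ordinary LIKE predicates. A small sketch with made-up data shows patterns that the rule turns into StartsWith, EndsWith, Contains and EqualTo respectively, which can be confirmed in the explain() output:

import org.apache.spark.sql.SparkSession

object LikeSimplificationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("like-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq("abcdef", "xyz", "mno", "abc").toDF("a")
    df.createOrReplaceTempView("t")

    // Each pattern below corresponds to one rewrite tested above.
    spark.sql("SELECT * FROM t WHERE a LIKE 'abc%'").explain(true) // StartsWith
    spark.sql("SELECT * FROM t WHERE a LIKE '%xyz'").explain(true) // EndsWith
    spark.sql("SELECT * FROM t WHERE a LIKE '%mn%'").explain(true) // Contains
    spark.sql("SELECT * FROM t WHERE a LIKE 'abc'").explain(true)  // EqualTo

    spark.stop()
  }
}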
Example 43
Source File: RewriteDistinctAggregatesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
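The same distinct-aggregate shapes can be produced from the public DataFrame API; a sketch with made-up columns, whose extended plan shows the Expand node introduced by the rewrite:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{countDistinct, sum}

object DistinctAggregatesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("distinct-agg-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(("a1", "b1", "c1", "d1", 1), ("a1", "b2", "c1", "d2", 2))
      .toDF("a", "b", "c", "d", "e")

    // Multiple distinct groups plus a partial aggregate, as in the suite's test cases;
    // the optimizer rewrites this into an Expand-based plan.
    df.groupBy($"a")
      .agg(countDistinct($"b", $"c"), countDistinct($"d"), sum($"e"))
      .explain(true)

    spark.stop()
  }
}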
Example 44
Source File: resources.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
Example 45
Source File: GroupedIteratorSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 46
Source File: ShowTablesUsingCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.sources.DatasourceCatalog
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types.{StringType, StructField, StructType}


private[sql]
case class ShowTablesUsingCommand(provider: String, options: Map[String, String])
  extends LogicalPlan
  with RunnableCommand {

  override def output: Seq[Attribute] = StructType(
    StructField("TABLE_NAME", StringType, nullable = false) ::
    StructField("IS_TEMPORARY", StringType, nullable = false) ::
    StructField("KIND", StringType, nullable = false) ::
    Nil
  ).toAttributes

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val dataSource: Any = DatasourceResolver.resolverFor(sqlContext).newInstanceOf(provider)

    dataSource match {
      case describableRelation: DatasourceCatalog =>
        describableRelation
          .getRelations(sqlContext, new CaseInsensitiveMap(options))
          .map(relationInfo => Row(
            relationInfo.name,
            relationInfo.isTemporary.toString.toUpperCase,
            relationInfo.kind.toUpperCase))
      case _ =>
        throw new RuntimeException(s"The provided data source $provider does not support " +
        "showing its relations.")
    }
  }
} 
Example 47
Source File: DescribeTableUsingCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.TableIdentifierUtils._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources.{DatasourceCatalog, RelationInfo}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}


private[sql]
case class DescribeTableUsingCommand(
    name: TableIdentifier,
    provider: String,
    options: Map[String, String])
  extends LogicalPlan
  with RunnableCommand {

  override def output: Seq[Attribute] = StructType(
    StructField("TABLE_NAME", StringType, nullable = false) ::
    StructField("DDL_STMT", StringType, nullable = false) ::
    Nil
  ).toAttributes

  override def run(sqlContext: SQLContext): Seq[Row] = {
    // Convert the table name according to the case-sensitivity settings
    val tableId = name.toSeq
    val resolver = DatasourceResolver.resolverFor(sqlContext)
    val catalog = resolver.newInstanceOfTyped[DatasourceCatalog](provider)

    Seq(catalog
      .getRelation(sqlContext, tableId, new CaseInsensitiveMap(options)) match {
        case None => Row("", "")
        case Some(RelationInfo(relName, _, _, ddl, _)) => Row(
          relName, ddl.getOrElse(""))
    })
  }
} 
Example 48
Source File: DescCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources.commands.hive

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{Row, SQLContext}


case class DescCommand(ident: TableIdentifier) extends HiveRunnableCommand {

  override protected val commandName: String = s"DESC $ident"

  override def execute(sqlContext: SQLContext): Seq[Row] = {
    val plan = sqlContext.catalog.lookupRelation(ident)
    if (plan.resolved) {
      plan.schema.map { field =>
        Row(field.name, field.dataType.simpleString, None)
      }
    } else {
      Seq.empty
    }
  }

  override lazy val output: Seq[Attribute] =
    AttributeReference("col_name", StringType)() ::
    AttributeReference("data_type", StringType)() ::
    AttributeReference("comment", StringType)() :: Nil
} 
Example 49
Source File: dependenciesSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableDependencyCalculator
import org.apache.spark.sql.sources.{RelationKind, Table}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}

object DependenciesSystemTableProvider extends SystemTableProvider with LocalSpark {
  
  override def execute(): Seq[Row] = {
    val tables = getTables(sqlContext.catalog)
    val dependentsMap = buildDependentsMap(tables)

    def kindOf(tableIdentifier: TableIdentifier): String =
      tables
        .get(tableIdentifier)
        .map(plan => RelationKind.kindOf(plan).getOrElse(Table).name)
        .getOrElse(DependenciesSystemTable.UnknownType)
        .toUpperCase

    dependentsMap.flatMap {
      case (tableIdent, dependents) =>
        val curKind = kindOf(tableIdent)
        dependents.map { dependent =>
          val dependentKind = kindOf(dependent)
          Row(
            tableIdent.database.orNull,
            tableIdent.table,
            curKind,
            dependent.database.orNull,
            dependent.table,
            dependentKind,
            ReferenceDependency.id)
        }
    }.toSeq
  }

  override val schema: StructType = DependenciesSystemTable.schema
}

object DependenciesSystemTable extends SchemaEnumeration {
  val baseSchemaName = Field("BASE_SCHEMA_NAME", StringType, nullable = true)
  val baseObjectName = Field("BASE_OBJECT_NAME", StringType, nullable = false)
  val baseObjectType = Field("BASE_OBJECT_TYPE", StringType, nullable = false)
  val dependentSchemaName = Field("DEPENDENT_SCHEMA_NAME", StringType, nullable = true)
  val dependentObjectName = Field("DEPENDENT_OBJECT_NAME", StringType, nullable = false)
  val dependentObjectType = Field("DEPENDENT_OBJECT_TYPE", StringType, nullable = false)
  val dependencyType = Field("DEPENDENCY_TYPE", IntegerType, nullable = false)

  private[DependenciesSystemTable] val UnknownType = "UNKNOWN"
} 
Example 50
Source File: partitionFunctionSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.sql.execution.tablefunctions.OutputFormatter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.GenericUtil._


  private def typeNameOf(f: PartitionFunction): String = f match {
    case _: RangePartitionFunction => "RANGE"
    case _: BlockPartitionFunction => "BLOCK"
    case _: HashPartitionFunction => "HASH"
  }
}

object PartitionFunctionSystemTable extends SchemaEnumeration {
  val id = Field("ID", StringType, nullable = false)
  val functionType = Field("TYPE", StringType, nullable = false)
  val columnName = Field("COLUMN_NAME", StringType, nullable = false)
  val columnType = Field("COLUMN_TYPE", StringType, nullable = false)
  val boundaries = Field("BOUNDARIES", StringType, nullable = true)
  val block = Field("BLOCK_SIZE", IntegerType, nullable = true)
  val partitions = Field("PARTITIONS", IntegerType, nullable = true)
  val minP = Field("MIN_PARTITIONS", IntegerType, nullable = true)
  val maxP = Field("MAX_PARTITIONS", IntegerType, nullable = true)
} 
Example 51
Source File: sessionSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{Row, SQLConf, SQLContext}


  private def allSettingsOf(conf: SQLConf): Map[String, String] = {
    val setConfs = conf.getAllConfs
    val defaultConfs = conf.getAllDefinedConfs.collect {
      case (key, default, _) if !setConfs.contains(key) => key -> default
    }
    setConfs ++ defaultConfs
  }

  override def schema: StructType = SessionSystemTable.schema
}

object SessionSystemTable extends SchemaEnumeration {
  val section = Field("SECTION", StringType, nullable = false)
  val key = Field("KEY", StringType, nullable = false)
  val value = Field("VALUE", StringType, nullable = true)
} 
Example 52
Source File: tablesSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.sources._
import org.apache.spark.sql.sources.commands.WithOrigin
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.util.CollectionUtils.CaseInsensitiveMap
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.catalyst.CaseSensitivityUtils._

object TablesSystemTableProvider extends SystemTableProvider with LocalSpark with ProviderBound {
  
  override def buildScan(requiredColumns: Array[String],
                         filters: Array[Filter]): RDD[Row] =
    DatasourceResolver
      .resolverFor(sqlContext)
      .newInstanceOfTyped[DatasourceCatalog](provider) match {
      case catalog: DatasourceCatalog with DatasourceCatalogPushDown =>
        catalog.getRelations(sqlContext, options, requiredColumns, filters.toSeq.merge)
      case catalog: DatasourceCatalog =>
        val values =
          catalog
            .getRelations(sqlContext, new CaseInsensitiveMap(options))
            .map(relationInfo => Row(
              relationInfo.name,
              relationInfo.isTemporary.toString.toUpperCase,
              relationInfo.kind.toUpperCase,
              relationInfo.provider))
        val rows = schema.buildPrunedFilteredScan(requiredColumns, filters)(values)
        sparkContext.parallelize(rows)
    }
}

sealed trait TablesSystemTable extends SystemTable {
  override def schema: StructType = TablesSystemTable.schema
}

object TablesSystemTable extends SchemaEnumeration {
  val tableName = Field("TABLE_NAME", StringType, nullable = false)
  val isTemporary = Field("IS_TEMPORARY", StringType, nullable = false)
  val kind = Field("KIND", StringType, nullable = false)
  val provider = Field("PROVIDER", StringType, nullable = true)
} 
Example 53
Source File: metadataSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.tablefunctions.OutputFormatter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.catalyst.CaseSensitivityUtils._


  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] =
    DatasourceResolver
      .resolverFor(sqlContext)
      .newInstanceOfTyped[MetadataCatalog](provider) match {
      case catalog: MetadataCatalog with MetadataCatalogPushDown =>
        catalog.getTableMetadata(sqlContext, options, requiredColumns, filters.toSeq.merge)
      case catalog =>
        val rows = catalog.getTableMetadata(sqlContext, options).flatMap { tableMetadata =>
          val formatter = new OutputFormatter(tableMetadata.tableName, tableMetadata.metadata)
          formatter.format().map(Row.fromSeq)
        }
        sparkContext.parallelize(schema.buildPrunedFilteredScan(requiredColumns, filters)(rows))
    }

  override def schema: StructType = MetadataSystemTable.schema
}

object MetadataSystemTable extends SchemaEnumeration {
  val tableName = Field("TABLE_NAME", StringType, nullable = false)
  val metadataKey = Field("METADATA_KEY", StringType, nullable = true)
  val metadataValue = Field("METADATA_VALUE", StringType, nullable = true)
} 
Example 54
Source File: relationMappingSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.sql.SqlLikeRelation
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{Row, SQLContext}

object RelationMappingSystemTableProvider extends SystemTableProvider with LocalSpark {

  
  override def execute(): Seq[Row] = {
    sqlContext.tableNames().map { tableName =>
      val plan = sqlContext.catalog.lookupRelation(TableIdentifier(tableName))
      val sqlName = plan.collectFirst {
        case s: SqlLikeRelation =>
          s.relationName
        case LogicalRelation(s: SqlLikeRelation, _) =>
          s.relationName
      }
      Row(tableName, sqlName)
    }
  }
}

object RelationMappingSystemTable extends SchemaEnumeration {
  val sparkName = Field("RELATION_NAME", StringType, nullable = false)
  val providerName = Field("SQL_NAME", StringType, nullable = true)
} 
Example 55
Source File: NodeTests.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hierarchy

import org.apache.spark.Logging
import org.apache.spark.sql.types.{Node, NodeHelpers, StringType}

import scala.collection.mutable.ArrayBuffer

// scalastyle:off magic.number
// scalastyle:off file.size.limit
class NodeTests extends NodeUnitTestSpec with Logging {
    var nodes = ArrayBuffer[Node]()
    nodes += Node(path = null, pathDataType = StringType, ordPath = List(1L))
    nodes += Node(path = null, pathDataType = StringType, ordPath = List(1L, 1L))
    nodes += Node(path = null, pathDataType = StringType, ordPath = List(1L, 1L, 2L))
    nodes += Node(path = null, pathDataType = StringType, ordPath = List(1L, 1L, 3L))
    nodes += Node(path = null, pathDataType = StringType, ordPath = List(1L, 2L))
    nodes += Node(path = null, pathDataType = StringType, ordPath = List(1L, 3L))
    nodes += Node(path = null, pathDataType = StringType, ordPath = List(1L, 4L))
    nodes += Node(path = null, pathDataType = StringType, ordPath = List(1L, 4L, 1L))
    nodes += Node(path = null, pathDataType = StringType, ordPath = List(1L, 4L, 2L))
  log.info("Running unit tests for sorting class Node\n")
  nodes.toArray should equal {
    // deterministic generator:
    val myRand = new scala.util.Random(42)

    // take copy of array-buffer, shuffle it
    val shuffled_nodes = myRand.shuffle(nodes.toSeq)

    // shuffled?:
    shuffled_nodes should not equal nodes.toArray

    shuffled_nodes.sorted(NodeHelpers.OrderedNode)
  }
  log.info("Testing function compareToRecursive\n")
  val x = Node(null, null)

  0 should equal {x.compareToRecursive(Seq(), Seq())}
  0 should be > {x.compareToRecursive(Seq(), Seq(1))}
  0 should be < {x.compareToRecursive(Seq(1), Seq())}
  0 should equal {x.compareToRecursive(Seq(1,2), Seq(1,2))}
  0 should be < {x.compareToRecursive(Seq(1,2), Seq(1))}
  0 should be > {x.compareToRecursive(Seq(1), Seq(1,2))}

}
// scalastyle:on magic.number
// scalastyle:on file.size.limit 
Example 56
Source File: CollapseExpandSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.analysis.CollapseExpandSuite.SqlLikeCatalystSourceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.sources.sql.SqlLikeRelation
import org.apache.spark.sql.sources.{BaseRelation, CatalystSource, Table}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.util.PlanComparisonUtils._
import org.apache.spark.sql.{GlobalSapSQLContext, Row}
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar


class CollapseExpandSuite extends FunSuite with MockitoSugar with GlobalSapSQLContext {
  case object Leaf extends LeafNode {
    override def output: Seq[Attribute] = Seq.empty
  }

  test("Expansion with a single sequence of projections is correctly collapsed") {
    val expand =
      Expand(
        Seq(Seq('a.string, Literal(1))),
        Seq('a.string, 'gid.int),
        Leaf)

    val collapsed = CollapseExpand(expand)
    assertResult(normalizeExprIds(Project(Seq('a.string, Literal(1) as "gid"), Leaf)))(
      normalizeExprIds(collapsed))
  }

  test("Expansion with multiple projections is correctly collapsed") {
    val expand =
      Expand(
        Seq(
          Seq('a.string, Literal(1)),
          Seq('b.string, Literal(1))),
        Seq('a.string, 'gid1.int, 'b.string, 'gid2.int),
        Leaf)

    val collapsed = CollapseExpand(expand)
    assertResult(
      normalizeExprIds(
        Project(Seq(
            'a.string,
            Literal(1) as "gid1",
            'b.string,
            Literal(1) as "gid2"),
          Leaf)))(normalizeExprIds(collapsed))
  }

  test("Expand pushdown integration") {
    val relation = mock[SqlLikeCatalystSourceRelation]
    when(relation.supportsLogicalPlan(any[Expand]))
      .thenReturn(true)
    when(relation.isMultiplePartitionExecution(any[Seq[CatalystSource]]))
      .thenReturn(true)
    when(relation.schema)
      .thenReturn(StructType(StructField("foo", StringType) :: Nil))
    when(relation.relationName)
      .thenReturn("t")
    when(relation.logicalPlanToRDD(any[LogicalPlan]))
      .thenReturn(sc.parallelize(Seq(Row("a", 1), Row("b", 1), Row("a", 1))))

    sqlc.baseRelationToDataFrame(relation).registerTempTable("t")

    val dataFrame = sqlc.sql("SELECT COUNT(DISTINCT foo) FROM t")
    val Seq(Row(ct)) = dataFrame.collect().toSeq

    assertResult(2)(ct)
  }
}

object CollapseExpandSuite {
  abstract class SqlLikeCatalystSourceRelation
    extends BaseRelation
    with Table
    with SqlLikeRelation
    with CatalystSource
} 
Example 57
Source File: ResolveCountDistinctStarSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FunSuite
import org.scalatest.Inside._
import org.scalatest.mock.MockitoSugar
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}

import scala.collection.mutable.ArrayBuffer

class ResolveCountDistinctStarSuite extends FunSuite with MockitoSugar {
  val persons = new LogicalRelation(new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]
    override def schema: StructType = StructType(Seq(
      StructField("age", IntegerType),
      StructField("name", StringType)
    ))
  })

  test("Count distinct star is resolved correctly") {
    val projection = persons.select(UnresolvedAlias(
      AggregateExpression(Count(UnresolvedStar(None) :: Nil), Complete, true)))
    val stillNotCompletelyResolvedAggregate = SimpleAnalyzer.execute(projection)
    val resolvedAggregate = ResolveCountDistinctStar(SimpleAnalyzer)
                              .apply(stillNotCompletelyResolvedAggregate)
    inside(resolvedAggregate) {
      case Aggregate(Nil,
      ArrayBuffer(Alias(AggregateExpression(Count(expressions), Complete, true), _)), _) =>
        assert(expressions.collect {
          case a:AttributeReference => a.name
        }.toSet == Set("name", "age"))
    }
    assert(resolvedAggregate.resolved)
  }
} 
Example 58
Source File: HiveSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import com.sap.spark.{GlobalSparkContext, WithSapHiveContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.SapHiveContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.scalatest.FunSuite


class HiveSuite
  extends FunSuite
  with GlobalSparkContext
  with WithSapHiveContext {

  val schema = StructType(
    StructField("foo", StringType) ::
    StructField("bar", StringType) :: Nil)

  test("NewSession returns a new SapHiveContext") {
    val hiveContext = sqlc.asInstanceOf[SapHiveContext]
    val newHiveContext = hiveContext.newSession()

    assert(newHiveContext.isInstanceOf[SapHiveContext])
    assert(newHiveContext != hiveContext)
  }

  test("NewSession returns a hive context whose catalog is separated to the current one") {
    val newContext = sqlc.newSession()
    val emptyRdd = newContext.createDataFrame(sc.emptyRDD[Row], schema)
    emptyRdd.registerTempTable("foo")

    assert(!sqlc.tableNames().contains("foo"))
    assert(newContext.tableNames().contains("foo"))
  }
} 
Example 59
Source File: ExcelRelation.scala    From spark-hadoopoffice-ds   with Apache License 2.0 5 votes vote down vote up
package org.zuinnote.spark.office.excel

import scala.collection.JavaConversions._

import org.apache.spark.sql.sources.{ BaseRelation, TableScan }
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.SQLContext

import org.apache.spark.sql._
import org.apache.spark.rdd.RDD

import org.apache.hadoop.conf._
import org.apache.hadoop.mapreduce._

import org.apache.commons.logging.LogFactory
import org.apache.commons.logging.Log

import org.zuinnote.hadoop.office.format.common.dao._
import org.zuinnote.hadoop.office.format.mapreduce._

import org.zuinnote.spark.office.excel.util.ExcelFile


  override def buildScan: RDD[Row] = {
    // read ExcelRows
    val excelRowsRDD = ExcelFile.load(sqlContext, location, hadoopParams)
    // map to schema
    val schemaFields = schema.fields
    excelRowsRDD.flatMap(excelKeyValueTuple => {
      // map the Excel row data structure to a Spark SQL schema
      val rowArray = new Array[Any](excelKeyValueTuple._2.get.length)
      var i = 0;
      for (x <- excelKeyValueTuple._2.get) { // parse through the SpreadSheetCellDAO
        val spreadSheetCellDAOStructArray = new Array[String](schemaFields.length)
        val currentSpreadSheetCellDAO: Array[SpreadSheetCellDAO] = excelKeyValueTuple._2.get.asInstanceOf[Array[SpreadSheetCellDAO]]
        spreadSheetCellDAOStructArray(0) = currentSpreadSheetCellDAO(i).getFormattedValue
        spreadSheetCellDAOStructArray(1) = currentSpreadSheetCellDAO(i).getComment
        spreadSheetCellDAOStructArray(2) = currentSpreadSheetCellDAO(i).getFormula
        spreadSheetCellDAOStructArray(3) = currentSpreadSheetCellDAO(i).getAddress
        spreadSheetCellDAOStructArray(4) = currentSpreadSheetCellDAO(i).getSheetName
        // add the String array representing this cell to the row
        rowArray(i) = spreadSheetCellDAOStructArray
        i += 1
      }
      Some(Row.fromSeq(rowArray))
    })

  }

} 
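Note: the listing above is an excerpt; the enclosing ExcelRelation class and its members (location, hadoopParams, schema) are elided. For orientation, below is a minimal sketch of the same BaseRelation/TableScan contract, using a hypothetical StringRelation that exposes a single StringType column; it is illustrative only and not part of spark-hadoopoffice-ds.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical relation: one StringType column named "value", one Row per input string.
class StringRelation(@transient override val sqlContext: SQLContext, data: Seq[String])
  extends BaseRelation with TableScan {

  override def schema: StructType = StructType(StructField("value", StringType) :: Nil)

  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(data.map(Row(_)))
}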
Example 60
Source File: HttpStreamServerClientTest.scala    From spark-http-stream   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.http.HttpStreamClient
import org.junit.Assert
import org.junit.Test
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.types.FloatType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.ByteType
import org.apache.spark.sql.execution.streaming.http.HttpStreamServer
import org.apache.spark.sql.execution.streaming.http.StreamPrinter
import org.apache.spark.sql.execution.streaming.http.HttpStreamServerSideException


class HttpStreamServerClientTest {
	val ROWS1 = Array(Row("hello1", 1, true, 0.1f, 0.1d, 1L, '1'.toByte),
		Row("hello2", 2, false, 0.2f, 0.2d, 2L, '2'.toByte),
		Row("hello3", 3, true, 0.3f, 0.3d, 3L, '3'.toByte));

	val ROWS2 = Array(Row("hello"),
		Row("world"),
		Row("bye"),
		Row("world"));

	@Test
	def testHttpStreamIO() {
		//starts an HTTP server
		val kryoSerializer = new KryoSerializer(new SparkConf());
		val server = HttpStreamServer.start("/xxxx", 8080);

		val spark = SparkSession.builder.appName("testHttpTextSink").master("local[4]")
			.getOrCreate();
		spark.conf.set("spark.sql.streaming.checkpointLocation", "/tmp/");

		val sqlContext = spark.sqlContext;
		import spark.implicits._
		//add a local message buffer to the server, with 2 topics registered
		server.withBuffer()
			.addListener(new StreamPrinter())
			.createTopic[(String, Int, Boolean, Float, Double, Long, Byte)]("topic-1")
			.createTopic[String]("topic-2");

		val client = HttpStreamClient.connect("http://localhost:8080/xxxx");
		//tests schema of topics
		val schema1 = client.fetchSchema("topic-1");
		Assert.assertArrayEquals(Array[Object](StringType, IntegerType, BooleanType, FloatType, DoubleType, LongType, ByteType),
			schema1.fields.map(_.dataType).asInstanceOf[Array[Object]]);

		val schema2 = client.fetchSchema("topic-2");
		Assert.assertArrayEquals(Array[Object](StringType),
			schema2.fields.map(_.dataType).asInstanceOf[Array[Object]]);

		//prepare to consume messages
		val sid1 = client.subscribe("topic-1")._1;
		val sid2 = client.subscribe("topic-2")._1;

		//produces some data
		client.sendRows("topic-1", 1, ROWS1);

		val sid4 = client.subscribe("topic-1")._1;
		val sid5 = client.subscribe("topic-2")._1;

		client.sendRows("topic-2", 1, ROWS2);

		//consumes data
		val fetched = client.fetchStream(sid1).map(_.originalRow);
		Assert.assertArrayEquals(ROWS1.asInstanceOf[Array[Object]], fetched.asInstanceOf[Array[Object]]);
		//it is empty now
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid1).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(ROWS2.asInstanceOf[Array[Object]], client.fetchStream(sid2).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid4).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(ROWS2.asInstanceOf[Array[Object]], client.fetchStream(sid5).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid5).map(_.originalRow).asInstanceOf[Array[Object]]);

		client.unsubscribe(sid4);
		try {
			client.fetchStream(sid4);
			//exception should be thrown, because subscriber id is invalidated
			Assert.assertTrue(false);
		}
		catch {
			case e: Throwable ⇒
				e.printStackTrace();
				Assert.assertEquals(classOf[HttpStreamServerSideException], e.getClass);
		}

		server.stop();
	}
} 
Example 61
Source File: StringMapParitySpec.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.mleap.parity.feature

import ml.combust.mleap.core.feature.{HandleInvalid, StringMapModel}
import org.apache.spark.ml.mleap.feature.StringMap
import org.apache.spark.ml.parity.SparkParityBase
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.sql._
import org.apache.spark.sql.types.{StringType, StructType}

class StringMapParitySpec extends SparkParityBase {
  val names = Seq("alice", "andy", "kevin")
  val rows = spark.sparkContext.parallelize(Seq.tabulate(3) { i => Row(names(i)) })
  val schema = new StructType().add("name", StringType, nullable = false)

  override val dataset: DataFrame = spark.sqlContext.createDataFrame(rows, schema)

  override val sparkTransformer: Transformer = new Pipeline().setStages(Array(
    new StringMap(uid = "string_map", model = new StringMapModel(
      Map("alice" -> 0, "andy" -> 1, "kevin" -> 2)
    )).setInputCol("name").setOutputCol("index"),
    new StringMap(uid = "string_map2", model = new StringMapModel(
      // This map is missing the label "kevin". Exception is thrown unless HandleInvalid.Keep is set.
      Map("alice" -> 0, "andy" -> 1),
      handleInvalid = HandleInvalid.Keep, defaultValue = 1.0
    )).setInputCol("name").setOutputCol("index2")

  )).fit(dataset)
} 
Example 62
Source File: WrappersSpec.scala    From sparksql-scalapb   with Apache License 2.0 5 votes vote down vote up
package scalapb.spark

import com.example.protos.wrappers._
import org.apache.spark.sql.SparkSession
import org.apache.hadoop.io.ArrayPrimitiveWritable
import scalapb.GeneratedMessageCompanion
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.Row

import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

class WrappersSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll {
  val spark: SparkSession = SparkSession
    .builder()
    .appName("ScalaPB Demo")
    .master("local[2]")
    .getOrCreate()

  import spark.implicits.StringToColumn

  val data = Seq(
    PrimitiveWrappers(
      intValue = Option(45),
      stringValue = Option("boo"),
      ints = Seq(17, 19, 25),
      strings = Seq("foo", "bar")
    ),
    PrimitiveWrappers(
      intValue = None,
      stringValue = None,
      ints = Seq(17, 19, 25),
      strings = Seq("foo", "bar")
    )
  )

  "converting df with primitive wrappers" should "work with primitive implicits" in {
    import ProtoSQL.withPrimitiveWrappers.implicits._
    val df = ProtoSQL.withPrimitiveWrappers.createDataFrame(spark, data)
    df.schema.fields.map(_.dataType).toSeq must be(
      Seq(
        IntegerType,
        StringType,
        ArrayType(IntegerType, false),
        ArrayType(StringType, false)
      )
    )
    df.collect must contain theSameElementsAs (
      Seq(
        Row(45, "boo", Seq(17, 19, 25), Seq("foo", "bar")),
        Row(null, null, Seq(17, 19, 25), Seq("foo", "bar"))
      )
    )
  }

  "converting df with primitive wrappers" should "work with default implicits" in {
    import ProtoSQL.implicits._
    val df = ProtoSQL.createDataFrame(spark, data)
    df.schema.fields.map(_.dataType).toSeq must be(
      Seq(
        StructType(Seq(StructField("value", IntegerType, true))),
        StructType(Seq(StructField("value", StringType, true))),
        ArrayType(
          StructType(Seq(StructField("value", IntegerType, true))),
          false
        ),
        ArrayType(
          StructType(Seq(StructField("value", StringType, true))),
          false
        )
      )
    )
    df.collect must contain theSameElementsAs (
      Seq(
        Row(
          Row(45),
          Row("boo"),
          Seq(Row(17), Row(19), Row(25)),
          Seq(Row("foo"), Row("bar"))
        ),
        Row(
          null,
          null,
          Seq(Row(17), Row(19), Row(25)),
          Seq(Row("foo"), Row("bar"))
        )
      )
    )
  }
} 
Example 63
Source File: MultiStreamHandler.scala    From structured-streaming-application   with Apache License 2.0 5 votes vote down vote up
package knolx.spark

import knolx.Config._
import knolx.KnolXLogger
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{Encoders, SparkSession}


case class CurrentPowerConsumption(kwh: Double)

case class PowerConsumptionStatus(numOfReadings: Long, total: Double, avg: Double, status: String) {
  def compute(newReadings: List[Double]) = {
    val newTotal = newReadings.sum + total
    val newNumOfReadings = numOfReadings + newReadings.size
    val newAvg = newTotal / newNumOfReadings.toDouble

    PowerConsumptionStatus(newNumOfReadings, newTotal, newAvg, "ON")
  }
}

object MultiStreamHandler extends App with KnolXLogger {
  info("Creating Spark Session")
  val spark = SparkSession.builder().master(sparkMaster).appName(sparkAppName).getOrCreate()
  spark.sparkContext.setLogLevel("WARN")

  val updateStateFunc =
    (deviceId: String, newReadings: Iterator[(String, CurrentPowerConsumption)], state: GroupState[PowerConsumptionStatus]) => {
      val data = newReadings.toList.map { case(_, reading) => reading }.map(_.kwh)

      lazy val initialPowerConsumptionStatus = PowerConsumptionStatus(0L, 0D, 0D, "OFF")
      val currentState = state.getOption.fold(initialPowerConsumptionStatus.compute(data))(_.compute(data))

      val currentStatus =
        if(state.hasTimedOut) {
          // If we do not receive any reading for a device, we assume that it is OFF.
          currentState.copy(status = "OFF")
        } else {
          state.setTimeoutDuration("10 seconds")
          currentState
        }

      state.update(currentStatus)
      (deviceId, currentStatus)
    }

  info("Creating Streaming DF...")
  val dataStream =
    spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServer)
      .option("subscribe", topic)
      .option("failOnDataLoss", false)
      .option("includeTimestamp", true)
      .load()

  info("Writing data to Console...")
  import spark.implicits._

  implicit val currentPowerConsumptionEncoder = Encoders.kryo[CurrentPowerConsumption]
  implicit val powerConsumptionStatusEncoder = Encoders.kryo[PowerConsumptionStatus]

  val query =
    dataStream
      .select(col("key").cast(StringType).as("key"), col("value").cast(StringType).as("value"))
      .as[(String, String)]
      .map { case(deviceId, unit) =>
        (deviceId, CurrentPowerConsumption(Option(unit).fold(0D)(_.toDouble)))
      }
      .groupByKey { case(deviceId, _) => deviceId }
      .mapGroupsWithState[PowerConsumptionStatus, (String, PowerConsumptionStatus)](GroupStateTimeout.ProcessingTimeTimeout())(updateStateFunc)
      .toDF("deviceId", "current_status")
      .writeStream
      .format("console")
      .option("truncate", false)
      .outputMode(OutputMode.Update())
      .option("checkpointLocation", checkPointDir)
      .start()

  info("Waiting for the query to terminate...")
  query.awaitTermination()
  query.stop()
} 
Example 64
Source File: StructuredStreamingWordCount.scala    From structured-streaming-application   with Apache License 2.0 5 votes vote down vote up
package knolx.spark

import com.datastax.driver.core.Cluster
import knolx.Config._
import knolx.KnolXLogger
import knolx.spark.CassandraForeachWriter.writeToCassandra
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, sum}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StringType


object StructuredStreamingWordCount extends App with KnolXLogger {
  val cluster = Cluster.builder.addContactPoints(cassandraHosts).build
  val session = cluster.newSession()

  info("Creating Keypsace and tables in Cassandra...")
  session.execute(s"CREATE KEYSPACE IF NOT EXISTS $keyspace WITH " +
    "replication = {'class':'SimpleStrategy','replication_factor':1};")

  session.execute(s"CREATE TABLE IF NOT EXISTS $keyspace.wordcount ( word text PRIMARY KEY,count int );")

  info("Closing DB connection...")
  session.close()
  session.getCluster.close()

  info("Creating Spark Session")
  val spark = SparkSession.builder().master(sparkMaster).appName(sparkAppName).getOrCreate()
  spark.sparkContext.setLogLevel("WARN")

  info("Creating Streaming DF...")
  val dataStream =
    spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServer)
      .option("subscribe", topic)
      .load()

  info("Writing data to Cassandra...")
  val query =
    dataStream
      .select(col("value").cast(StringType).as("word"), lit(1).as("count"))
      .groupBy(col("word"))
      .agg(sum("count").as("count"))
      .writeStream
      .outputMode(OutputMode.Update())
      .foreach(writeToCassandra)
      .option("checkpointLocation", checkPointDir)
      .start()

  info("Waiting for the query to terminate...")
  query.awaitTermination()
  query.stop()
} 
Example 65
Source File: EventHubsWriter.scala    From azure-event-hubs-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.eventhubs

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{ AnalysisException, SparkSession }
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.{ BinaryType, StringType }
import org.apache.spark.util.Utils


private[eventhubs] object EventHubsWriter extends Logging {

  val BodyAttributeName = "body"
  val PartitionKeyAttributeName = "partitionKey"
  val PartitionIdAttributeName = "partition"
  val PropertiesAttributeName = "properties"

  override def toString: String = "EventHubsWriter"

  private def validateQuery(schema: Seq[Attribute], parameters: Map[String, String]): Unit = {
    schema
      .find(_.name == BodyAttributeName)
      .getOrElse(
        throw new AnalysisException(s"Required attribute '$BodyAttributeName' not found.")
      )
      .dataType match {
      case StringType | BinaryType => // good
      case _ =>
        throw new AnalysisException(
          s"$BodyAttributeName attribute type " +
            s"must be a String or BinaryType.")
    }
  }

  def write(
      sparkSession: SparkSession,
      queryExecution: QueryExecution,
      parameters: Map[String, String]
  ): Unit = {
    val schema = queryExecution.analyzed.output
    validateQuery(schema, parameters)
    queryExecution.toRdd.foreachPartition { iter =>
      val writeTask = new EventHubsWriteTask(parameters, schema)
      Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
        finallyBlock = writeTask.close()
      )
    }
  }
} 
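validateQuery above only requires that the projected output contain a body attribute of StringType or BinaryType. As a rough sketch (names are placeholders and the Event Hubs connection options are omitted), a DataFrame that would pass this check can be built like so:

import org.apache.spark.sql.SparkSession

// Hypothetical session; in a real job this would be the streaming query's session.
val spark = SparkSession.builder().master("local[2]").appName("bodySchemaSketch").getOrCreate()
import spark.implicits._

// A single StringType column named "body" satisfies EventHubsWriter.validateQuery.
val events = Seq("hello", "world").toDF("body")
events.printSchema()   // root |-- body: string (nullable = true)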
Example 66
Source File: KustoSourceTests.scala    From azure-kusto-spark   with Apache License 2.0 5 votes vote down vote up
package com.microsoft.kusto.spark

import com.microsoft.kusto.spark.datasource.KustoSourceOptions
import com.microsoft.kusto.spark.utils.{KustoDataSourceUtils => KDSU}
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.junit.runner.RunWith
import org.scalamock.scalatest.MockFactory
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class KustoSourceTests extends FlatSpec with MockFactory with Matchers with BeforeAndAfterAll {
  private val loggingLevel: Option[String] = Option(System.getProperty("logLevel"))
  if (loggingLevel.isDefined) KDSU.setLoggingLevel(loggingLevel.get)

  private val nofExecutors = 4
  private val spark: SparkSession = SparkSession.builder()
    .appName("KustoSource")
    .master(f"local[$nofExecutors]")
    .getOrCreate()

  private var sc: SparkContext = _
  private var sqlContext: SQLContext = _
  private val cluster: String = "KustoCluster"
  private val database: String = "KustoDatabase"
  private val query: String = "KustoTable"
  private val appId: String = "KustoSinkTestApplication"
  private val appKey: String = "KustoSinkTestKey"
  private val appAuthorityId: String = "KustoSinkAuthorityId"

  override def beforeAll(): Unit = {
    super.beforeAll()

    sc = spark.sparkContext
    sqlContext = spark.sqlContext
  }

  override def afterAll(): Unit = {
    super.afterAll()

    sc.stop()
  }

  "KustoDataSource" should "recognize Kusto and get the correct schema" in {
    val spark: SparkSession = SparkSession.builder()
      .appName("KustoSource")
      .master(f"local[$nofExecutors]")
      .getOrCreate()

    val customSchema = "colA STRING, colB INT"

    val df = spark.sqlContext
      .read
      .format("com.microsoft.kusto.spark.datasource")
      .option(KustoSourceOptions.KUSTO_CLUSTER, cluster)
      .option(KustoSourceOptions.KUSTO_DATABASE, database)
      .option(KustoSourceOptions.KUSTO_QUERY, query)
      .option(KustoSourceOptions.KUSTO_AAD_APP_ID, appId)
      .option(KustoSourceOptions.KUSTO_AAD_APP_SECRET, appKey)
      .option(KustoSourceOptions.KUSTO_AAD_AUTHORITY_ID, appAuthorityId)
      .option(KustoSourceOptions.KUSTO_CUSTOM_DATAFRAME_COLUMN_TYPES, customSchema)
      .load("src/test/resources/")

    val expected = StructType(Array(StructField("colA", StringType, nullable = true),StructField("colB", IntegerType, nullable = true)))
    assert(df.schema.equals(expected))
  }
} 
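The expected StructType mirrors the DDL string passed through KUSTO_CUSTOM_DATAFRAME_COLUMN_TYPES. As a quick cross-check outside the connector (assuming Spark 2.3+, where StructType.fromDDL is available), the same schema can be derived directly from the DDL text:

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// StructType.fromDDL parses a DDL-style column list into the equivalent StructType.
val fromDdl = StructType.fromDDL("colA STRING, colB INT")
val expected = StructType(Array(
  StructField("colA", StringType, nullable = true),
  StructField("colB", IntegerType, nullable = true)))
assert(fromDdl == expected)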
Example 67
Source File: RawDataWriterHelper.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.writer

import com.stratio.sparta.driver.factory.SparkContextFactory
import com.stratio.sparta.driver.step.RawData
import com.stratio.sparta.sdk.pipeline.output.Output
import com.stratio.sparta.sdk.utils.AggregationTime
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
import org.apache.spark.streaming.dstream.DStream


object RawDataWriterHelper {

  def writeRawData(rawData: RawData, outputs: Seq[Output], input: DStream[Row]): Unit = {
    val RawSchema = StructType(Seq(
      StructField(rawData.timeField, TimestampType, nullable = false),
      StructField(rawData.dataField, StringType, nullable = true)))
    val eventTime = AggregationTime.millisToTimeStamp(System.currentTimeMillis())

    input.map(row => Row.merge(Row(eventTime), row))
      .foreachRDD(rdd => {
        if (!rdd.isEmpty()) {
          val rawDataFrame = SparkContextFactory.sparkSessionInstance.createDataFrame(rdd, RawSchema)

          WriterHelper.write(rawDataFrame, rawData.writerOptions, Map.empty[String, String], outputs)
        }
      })
  }
} 
Example 68
Source File: CubeMakerTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.test.cube

import java.sql.Timestamp

import com.github.nscala_time.time.Imports._
import com.stratio.sparta.driver.step.{Cube, CubeOperations, Trigger}
import com.stratio.sparta.driver.writer.WriterOptions
import com.stratio.sparta.plugin.default.DefaultField
import com.stratio.sparta.plugin.cube.field.datetime.DateTimeField
import com.stratio.sparta.plugin.cube.operator.count.CountOperator
import com.stratio.sparta.sdk.pipeline.aggregation.cube.{Dimension, DimensionValue, DimensionValuesTime, InputFields}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.utils.AggregationTime
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.streaming.TestSuiteBase
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CubeMakerTest extends TestSuiteBase {

  val PreserverOrder = false

  
  def getEventOutput(timestamp: Timestamp, millis: Long):
  Seq[Seq[(DimensionValuesTime, InputFields)]] = {
    val dimensionString = Dimension("dim1", "eventKey", "identity", new DefaultField)
    val dimensionTime = Dimension("minute", "minute", "minute", new DateTimeField)
    val dimensionValueString1 = DimensionValue(dimensionString, "value1")
    val dimensionValueString2 = dimensionValueString1.copy(value = "value2")
    val dimensionValueString3 = dimensionValueString1.copy(value = "value3")
    val dimensionValueTs = DimensionValue(dimensionTime, timestamp)
    val tsMap = Row(timestamp)
    val valuesMap1 = InputFields(Row("value1", timestamp), 1)
    val valuesMap2 = InputFields(Row("value2", timestamp), 1)
    val valuesMap3 = InputFields(Row("value3", timestamp), 1)

    Seq(Seq(
      (DimensionValuesTime("cubeName", Seq(dimensionValueString1, dimensionValueTs)), valuesMap1),
      (DimensionValuesTime("cubeName", Seq(dimensionValueString2, dimensionValueTs)), valuesMap2),
      (DimensionValuesTime("cubeName", Seq(dimensionValueString3, dimensionValueTs)), valuesMap3)
    ))
  }
} 
Example 69
Source File: ParserTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.sdk.pipeline.transformation

import java.io.{Serializable => JSerializable}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ParserTest extends WordSpec with Matchers {

  "Parser" should {

    val parserTest = new ParserMock(
      1,
      Some("input"),
      Seq("output"),
      StructType(Seq(StructField("some", StringType))),
      Map()
    )

    "Order must be " in {
      val expected = 1
      val result = parserTest.getOrder
      result should be(expected)
    }

    "Parse must be " in {
      val event = Row("value")
      val expected = Seq(event)
      val result = parserTest.parse(event)
      result should be(expected)
    }

    "checked fields not be contained in outputs must be " in {
      val keyMap = Map("field" -> "value")
      val expected = Map()
      val result = parserTest.checkFields(keyMap)
      result should be(expected)
    }

    "checked fields are contained in outputs must be " in {
      val keyMap = Map("output" -> "value")
      val expected = keyMap
      val result = parserTest.checkFields(keyMap)
      result should be(expected)
    }

    "classSuffix must be " in {
      val expected = "Parser"
      val result = Parser.ClassSuffix
      result should be(expected)
    }
  }
} 
Example 70
Source File: CassandraOutputTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.output.cassandra

import java.io.{Serializable => JSerializable}

import com.datastax.spark.connector.cql.CassandraConnector
import com.stratio.sparta.sdk._
import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class CassandraOutputTest extends FlatSpec with Matchers with MockitoSugar with AnswerSugar {

  val s = "sum"
  val properties = Map(("connectionHost", "127.0.0.1"), ("connectionPort", "9042"))

  "getSparkConfiguration" should "return a Seq with the configuration" in {
    val configuration = Map(("connectionHost", "127.0.0.1"), ("connectionPort", "9042"))
    val cass = CassandraOutput.getSparkConfiguration(configuration)

    cass should be(List(("spark.cassandra.connection.host", "127.0.0.1"), ("spark.cassandra.connection.port", "9042")))
  }

  "getSparkConfiguration" should "return all cassandra-spark config" in {
    val config: Map[String, JSerializable] = Map(
      ("sparkProperties" -> JsoneyString(
        "[{\"sparkPropertyKey\":\"spark.cassandra.input.fetch.size_in_rows\",\"sparkPropertyValue\":\"2000\"}," +
          "{\"sparkPropertyKey\":\"spark.cassandra.input.split.size_in_mb\",\"sparkPropertyValue\":\"64\"}]")),
      ("anotherProperty" -> "true")
    )

    val sparkConfig = CassandraOutput.getSparkConfiguration(config)

    sparkConfig.exists(_ == ("spark.cassandra.input.fetch.size_in_rows" -> "2000")) should be(true)
    sparkConfig.exists(_ == ("spark.cassandra.input.split.size_in_mb" -> "64")) should be(true)
    sparkConfig.exists(_ == ("anotherProperty" -> "true")) should be(false)
  }

  "getSparkConfiguration" should "not return cassandra-spark config" in {
    val config: Map[String, JSerializable] = Map(
      ("hadoopProperties" -> JsoneyString(
        "[{\"sparkPropertyKey\":\"spark.cassandra.input.fetch.size_in_rows\",\"sparkPropertyValue\":\"2000\"}," +
          "{\"sparkPropertyKey\":\"spark.cassandra.input.split.size_in_mb\",\"sparkPropertyValue\":\"64\"}]")),
      ("anotherProperty" -> "true")
    )

    val sparkConfig = CassandraOutput.getSparkConfiguration(config)

    sparkConfig.exists(_ == ("spark.cassandra.input.fetch.size_in_rows" -> "2000")) should be(false)
    sparkConfig.exists(_ == ("spark.cassandra.input.split.size_in_mb" -> "64")) should be(false)
    sparkConfig.exists(_ == ("anotherProperty" -> "true")) should be(false)
  }
} 
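Since getSparkConfiguration just yields ordinary Spark property pairs, they can be applied directly to a SparkConf. A minimal sketch under that assumption, reusing the literal values from the first test above:

import org.apache.spark.SparkConf

// Apply the connector properties returned by CassandraOutput.getSparkConfiguration.
val conf = new SparkConf().setAll(Seq(
  ("spark.cassandra.connection.host", "127.0.0.1"),
  ("spark.cassandra.connection.port", "9042")))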
Example 71
Source File: OperatorEntityCountTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.entityCount

import java.io.{Serializable => JSerializable}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class OperatorEntityCountTest extends WordSpec with Matchers {

  "EntityCount" should {
    val props = Map(
      "inputField" -> "inputField".asInstanceOf[JSerializable],
      "split" -> ",".asInstanceOf[JSerializable])
    val schema = StructType(Seq(StructField("inputField", StringType)))
    val entityCount = new OperatorEntityCountMock("op1", schema, props)
    val inputFields = Row("hello,bye")

    "Return the associated precision name" in {
      val expected = Option(Seq("hello", "bye"))
      val result = entityCount.processMap(inputFields)
      result should be(expected)
    }

    "Return empty list" in {
      val expected = None
      val result = entityCount.processMap(Row())
      result should be(expected)
    }
  }
} 
Example 72
Source File: MorphlinesParserTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.morphline

import java.io.Serializable

import com.stratio.sparta.sdk.pipeline.input.Input
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpecLike}


@RunWith(classOf[JUnitRunner])
class MorphlinesParserTest extends WordSpecLike with Matchers with BeforeAndAfter with BeforeAndAfterAll {

  val morphlineConfig = """
          id : test1
          importCommands : ["org.kitesdk.**"]
          commands: [
          {
              readJson {},
          }
          {
              extractJsonPaths {
                  paths : {
                      col1 : /col1
                      col2 : /col2
                  }
              }
          }
          {
            java {
              code : "return child.process(record);"
            }
          }
          {
              removeFields {
                  blacklist:["literal:_attachment_body"]
              }
          }
          ]
                        """
  val inputField = Some(Input.RawDataKey)
  val outputsFields = Seq("col1", "col2")
  val props: Map[String, Serializable] = Map("morphline" -> morphlineConfig)

  val schema = StructType(Seq(StructField("col1", StringType), StructField("col2", StringType)))

  val parser = new MorphlinesParser(1, inputField, outputsFields, schema, props)

  "A MorphlinesParser" should {

    "parse a simple json" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row(simpleJson, "hello", "world"))

      result should be eq(expected)
    }

    "parse a simple json removing raw" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row("hello", "world"))

      result should be eq(expected)
    }

    "exclude not configured fields" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word",
            "col3":"!"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row(simpleJson, "hello", "world"))

      result should be eq(expected)
    }
  }
} 
Example 73
Source File: DateTimeParserTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.datetime

import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class DateTimeParserTest extends WordSpecLike with Matchers {

  val inputField = Some("ts")
  val outputsFields = Seq("ts")

  //scalastyle:off
  "A DateTimeParser" should {
    "parse unixMillis to string" in {
      val input = Row(1416330788000L)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unixMillis"))
          .parse(input)

      val expected = Seq(Row(1416330788000L, "1416330788000"))

      assertResult(result)(expected)
    }

    "parse unix to string" in {
      val input = Row(1416330788)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unix"))
          .parse(input)

      val expected = Seq(Row(1416330788, "1416330788000"))

      assertResult(result)(expected)
    }

    "parse unix to string removing raw" in {
      val input = Row(1416330788)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unix",
          "removeInputField" -> JsoneyString.apply("true")))
          .parse(input)

      val expected = Seq(Row("1416330788000"))

      assertResult(result)(expected)
    }

    "not parse anything if the field does not match" in {
      val input = Row("1212")
      val schema = StructType(Seq(StructField("otherField", StringType)))

      an[IllegalStateException] should be thrownBy new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unixMillis")).parse(input)
    }

    "not parse anything and generate a new Date" in {
      val input = Row("anything")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "autoGenerated"))
          .parse(input)

      assertResult(result.head.size)(2)
    }

    "Auto generated if inputFormat does not exist" in {
      val input = Row("1416330788")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map()).parse(input)

      assertResult(result.head.size)(2)
    }

    "parse dateTime in hive format" in {
      val input = Row("2015-11-08 15:58:58")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "hive"))
          .parse(input)

      val expected = Seq(Row("2015-11-08 15:58:58", "1446998338000"))

      assertResult(result)(expected)
    }
  }
} 
Example 74
Source File: StatisticsTest.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.statistics

import java.io.ByteArrayOutputStream

import scala.collection.mutable.ArrayBuffer

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BaseOrdering
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.execution.datasources.oap.filecache.FiberCache
import org.apache.spark.sql.execution.datasources.oap.index.RangeInterval
import org.apache.spark.sql.execution.datasources.oap.utils.{NonNullKeyReader, NonNullKeyWriter}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.memory.MemoryBlock
import org.apache.spark.unsafe.types.UTF8String

abstract class StatisticsTest extends SparkFunSuite with BeforeAndAfterEach {

  protected def rowGen(i: Int): InternalRow = InternalRow(i, UTF8String.fromString(s"test#$i"))

  protected lazy val schema: StructType = StructType(StructField("a", IntegerType)
    :: StructField("b", StringType) :: Nil)
  @transient
  protected lazy val nnkw: NonNullKeyWriter = new NonNullKeyWriter(schema)
  @transient
  protected lazy val nnkr: NonNullKeyReader = new NonNullKeyReader(schema)
  @transient
  protected lazy val ordering: BaseOrdering = GenerateOrdering.create(schema)
  @transient
  protected lazy val partialOrdering: BaseOrdering =
    GenerateOrdering.create(StructType(schema.dropRight(1)))
  protected var out: ByteArrayOutputStream = _

  protected var intervalArray: ArrayBuffer[RangeInterval] = new ArrayBuffer[RangeInterval]()

  override def beforeEach(): Unit = {
    out = new ByteArrayOutputStream(8000)
  }

  override def afterEach(): Unit = {
    out.close()
    intervalArray.clear()
  }

  protected def generateInterval(
      start: InternalRow, end: InternalRow,
      startInclude: Boolean, endInclude: Boolean): Unit = {
    intervalArray.clear()
    intervalArray.append(new RangeInterval(start, end, startInclude, endInclude))
  }

  protected def checkInternalRow(row1: InternalRow, row2: InternalRow): Unit = {
    val res = row1 == row2 // InternalRow equality is value-based, so == compares the field contents
    assert(res, s"row1: $row1 does not match $row2")
  }

  protected def wrapToFiberCache(out: ByteArrayOutputStream): FiberCache = {
    val bytes = out.toByteArray
    FiberCache(bytes)
  }
} 
Example 75
Source File: DeltaByteArrayEncoderSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.io

import org.scalacheck.{Arbitrary, Gen, Properties}
import org.scalacheck.Prop.forAll
import org.scalatest.prop.Checkers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.oap.adapter.PropertiesAdapter
import org.apache.spark.sql.execution.datasources.oap.filecache.StringFiberBuilder
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

class DeltaByteArrayEncoderCheck extends Properties("DeltaByteArrayEncoder") {

  private val rowCountInEachGroup = Gen.choose(1, 1024)
  private val rowCountInLastGroup = Gen.choose(1, 1024)
  private val groupCount = Gen.choose(1, 100)

  property("Encoding/Decoding String Type") = forAll { (values: Array[String]) =>

    forAll(rowCountInEachGroup, rowCountInLastGroup, groupCount) {
      (rowCount, lastCount, groupCount) =>
        if (values.nonEmpty) {
          // This is the 'PLAIN' FiberBuilder to validate the 'Encoding/Decoding'
          // Normally, the test case should be:
          // values => encoded bytes => decoded bytes => decoded values (Using ColumnValues class)
          // Validate if 'values' and 'decoded values' are identical.
          // But ColumnValues only supports reading values from a DataFile, so we have to use another way
          // to validate.
          val referenceFiberBuilder = StringFiberBuilder(rowCount, 0)
          val fiberBuilder = DeltaByteArrayFiberBuilder(rowCount, 0, StringType)
          val fiberParser = DeltaByteArrayDataFiberParser(
            new OapDataFileMetaV1(rowCountInEachGroup = rowCount), StringType)
          !(0 until groupCount).exists { group =>
            // If lastCount > rowCount, assume lastCount = rowCount
            val count = if (group < groupCount - 1) {
              rowCount
            } else if (lastCount > rowCount) {
              rowCount
            } else {
              lastCount
            }
            (0 until count).foreach { row =>
              fiberBuilder.append(InternalRow(UTF8String.fromString(values(row % values.length))))
              referenceFiberBuilder
                .append(InternalRow(UTF8String.fromString(values(row % values.length))))
            }
            val bytes = fiberBuilder.build().fiberData
            val parsedBytes = fiberParser.parse(bytes, count)
            val referenceBytes = referenceFiberBuilder.build().fiberData
            referenceFiberBuilder.clear()
            fiberBuilder.clear()
            assert(parsedBytes.length == referenceBytes.length)
            parsedBytes.zip(referenceBytes).exists(byte => byte._1 != byte._2)
          }
        } else true
    }
  }
}

class DeltaByteArrayEncoderSuite extends SparkFunSuite with Checkers {

  test("Check Encoding/Decoding") {
    check(PropertiesAdapter.getProp(new DeltaByteArrayEncoderCheck()))
  }
} 
Example 76
Source File: DictionaryBasedEncoderSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.io

import org.apache.parquet.bytes.BytesInput
import org.apache.parquet.column.page.DictionaryPage
import org.apache.parquet.column.values.dictionary.PlainValuesDictionary.PlainBinaryDictionary
import org.scalacheck.{Arbitrary, Gen, Properties}
import org.scalacheck.Prop.forAll
import org.scalatest.prop.Checkers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.oap.adapter.PropertiesAdapter
import org.apache.spark.sql.execution.datasources.oap.filecache.StringFiberBuilder
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

class DictionaryBasedEncoderCheck extends Properties("DictionaryBasedEncoder") {
  private val rowCountInEachGroup = Gen.choose(1, 1024)
  private val rowCountInLastGroup = Gen.choose(1, 1024)
  private val groupCount = Gen.choose(1, 100)

  property("Encoding/Decoding String Type") = forAll { (values: Array[String]) =>

    forAll(rowCountInEachGroup, rowCountInLastGroup, groupCount) {
      (rowCount, lastCount, groupCount) =>
        if (values.nonEmpty) {
          // This is the 'PLAIN' FiberBuilder to validate the 'Encoding/Decoding'
          // Normally, the test case should be:
          // values => encoded bytes => decoded bytes => decoded values (Using ColumnValues class)
          // Validate if 'values' and 'decoded values' are identical.
          // But ColumnValues only supports reading values from a DataFile, so we have to use another way
          // to validate.
          val referenceFiberBuilder = StringFiberBuilder(rowCount, 0)
          val fiberBuilder = PlainBinaryDictionaryFiberBuilder(rowCount, 0, StringType)
          !(0 until groupCount).exists { group =>
            // If lastCount > rowCount, assume lastCount = rowCount
            val count =
              if (group < groupCount - 1) {
                rowCount
              } else if (lastCount > rowCount) {
                rowCount
              } else {
                lastCount
              }
            (0 until count).foreach { row =>
              fiberBuilder.append(InternalRow(UTF8String.fromString(values(row % values.length))))
              referenceFiberBuilder
                .append(InternalRow(UTF8String.fromString(values(row % values.length))))
            }
            val bytes = fiberBuilder.build().fiberData
            val dictionary = new PlainBinaryDictionary(
              new DictionaryPage(
                BytesInput.from(fiberBuilder.buildDictionary),
                fiberBuilder.getDictionarySize,
                org.apache.parquet.column.Encoding.PLAIN))
            val fiberParser = PlainDictionaryFiberParser(
              new OapDataFileMetaV1(rowCountInEachGroup = rowCount), dictionary, StringType)
            val parsedBytes = fiberParser.parse(bytes, count)
            val referenceBytes = referenceFiberBuilder.build().fiberData
            referenceFiberBuilder.clear()
            referenceFiberBuilder.resetDictionary()
            fiberBuilder.clear()
            fiberBuilder.resetDictionary()
            assert(parsedBytes.length == referenceBytes.length)
            parsedBytes.zip(referenceBytes).exists(byte => byte._1 != byte._2)
          }
        } else {
          true
        }
    }
  }
}

class DictionaryBasedEncoderSuite extends SparkFunSuite with Checkers {

  test("Check Encoding/Decoding") {
    check(PropertiesAdapter.getProp(new DictionaryBasedEncoderCheck()))
  }
} 
Example 77
Source File: MLPipelineTrackerIT.scala    From spark-atlas-connector   with Apache License 2.0 5 votes vote down vote up
package com.hortonworks.spark.atlas.ml

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.scalatest.Matchers
import com.hortonworks.spark.atlas._
import com.hortonworks.spark.atlas.types._
import com.hortonworks.spark.atlas.TestUtils._

class MLPipelineTrackerIT extends BaseResourceIT with Matchers with WithHiveSupport {
  private val atlasClient = new RestAtlasClient(atlasClientConf)

  def clusterName: String = atlasClientConf.get(AtlasClientConf.CLUSTER_NAME)

  def getTableEntity(tableName: String): SACAtlasEntityWithDependencies = {
    val dbDefinition = createDB("db1", "hdfs:///test/db/db1")
    val sd = createStorageFormat()
    val schema = new StructType()
      .add("user", StringType, false)
      .add("age", IntegerType, true)
    val tableDefinition = createTable("db1", s"$tableName", schema, sd)
    internal.sparkTableToEntity(tableDefinition, clusterName, Some(dbDefinition))
  }

  // Enable this to run the integration test.
  it("pipeline and pipeline model") {
    val uri = "hdfs://"
    val pipelineDir = "tmp/pipeline"
    val modelDir = "tmp/model"

    val pipelineDirEntity = internal.mlDirectoryToEntity(uri, pipelineDir)
    val modelDirEntity = internal.mlDirectoryToEntity(uri, modelDir)

    atlasClient.createEntitiesWithDependencies(Seq(pipelineDirEntity, modelDirEntity))

    val df = sparkSession.createDataFrame(Seq(
      (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
      (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
      (3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
      (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)
    )).toDF("id", "features", "label")

    val scaler = new MinMaxScaler()
      .setInputCol("features")
      .setOutputCol("features_scaled")
      .setMin(0.0)
      .setMax(3.0)
    val pipeline = new Pipeline().setStages(Array(scaler))

    val model = pipeline.fit(df)

    pipeline.write.overwrite().save(pipelineDir)

    val pipelineEntity = internal.mlPipelineToEntity(pipeline.uid, pipelineDirEntity)

    atlasClient.createEntitiesWithDependencies(Seq(pipelineDirEntity, pipelineEntity))

    val modelEntity = internal.mlModelToEntity(model.uid, modelDirEntity)

    atlasClient.createEntitiesWithDependencies(Seq(modelDirEntity, modelEntity))

    val tableEntities1 = getTableEntity("chris1")
    val tableEntities2 = getTableEntity("chris2")

    atlasClient.createEntitiesWithDependencies(tableEntities1)
    atlasClient.createEntitiesWithDependencies(tableEntities2)

  }
} 
Example 78
Source File: MLAtlasEntityUtilsSuite.scala    From spark-atlas-connector   with Apache License 2.0 5 votes vote down vote up
package com.hortonworks.spark.atlas.types

import java.io.File

import org.apache.atlas.{AtlasClient, AtlasConstants}
import org.apache.atlas.model.instance.AtlasEntity
import org.apache.commons.io.FileUtils
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.scalatest.{FunSuite, Matchers}
import com.hortonworks.spark.atlas.TestUtils._
import com.hortonworks.spark.atlas.{AtlasUtils, WithHiveSupport}

class MLAtlasEntityUtilsSuite extends FunSuite with Matchers with WithHiveSupport {

  def getTableEntity(tableName: String): AtlasEntity = {
    val dbDefinition = createDB("db1", "hdfs:///test/db/db1")
    val sd = createStorageFormat()
    val schema = new StructType()
      .add("user", StringType, false)
      .add("age", IntegerType, true)
    val tableDefinition = createTable("db1", s"$tableName", schema, sd)

    val tableEntities = internal.sparkTableToEntity(
      tableDefinition, AtlasConstants.DEFAULT_CLUSTER_NAME, Some(dbDefinition))
    val tableEntity = tableEntities.entity

    tableEntity
  }

  test("pipeline, pipeline model, fit and transform") {
    val uri = "/"
    val pipelineDir = "tmp/pipeline"
    val modelDir = "tmp/model"

    val pipelineDirEntity = internal.mlDirectoryToEntity(uri, pipelineDir)
    pipelineDirEntity.entity.getAttribute("uri") should be (uri)
    pipelineDirEntity.entity.getAttribute("directory") should be (pipelineDir)
    pipelineDirEntity.dependencies.length should be (0)

    val modelDirEntity = internal.mlDirectoryToEntity(uri, modelDir)
    modelDirEntity.entity.getAttribute("uri") should be (uri)
    modelDirEntity.entity.getAttribute("directory") should be (modelDir)
    modelDirEntity.dependencies.length should be (0)

    val df = sparkSession.createDataFrame(Seq(
      (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
      (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
      (3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
      (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)
    )).toDF("id", "features", "label")

    val scaler = new MinMaxScaler()
      .setInputCol("features")
      .setOutputCol("features_scaled")
      .setMin(0.0)
      .setMax(3.0)
    val pipeline = new Pipeline().setStages(Array(scaler))

    val model = pipeline.fit(df)

    pipeline.write.overwrite().save(pipelineDir)

    val pipelineEntity = internal.mlPipelineToEntity(pipeline.uid, pipelineDirEntity)
    pipelineEntity.entity.getTypeName should be (metadata.ML_PIPELINE_TYPE_STRING)
    pipelineEntity.entity.getAttribute(AtlasClient.REFERENCEABLE_ATTRIBUTE_NAME) should be (
      pipeline.uid)
    pipelineEntity.entity.getAttribute("name") should be (pipeline.uid)
    pipelineEntity.entity.getRelationshipAttribute("directory") should be (
      AtlasUtils.entityToReference(pipelineDirEntity.entity, useGuid = false))
    pipelineEntity.dependencies should be (Seq(pipelineDirEntity))

    val modelEntity = internal.mlModelToEntity(model.uid, modelDirEntity)
    val modelUid = model.uid.replaceAll("pipeline", "model")
    modelEntity.entity.getTypeName should be (metadata.ML_MODEL_TYPE_STRING)
    modelEntity.entity.getAttribute(AtlasClient.REFERENCEABLE_ATTRIBUTE_NAME) should be (modelUid)
    modelEntity.entity.getAttribute("name") should be (modelUid)
    modelEntity.entity.getRelationshipAttribute("directory") should be (
      AtlasUtils.entityToReference(modelDirEntity.entity, useGuid = false))

    modelEntity.dependencies should be (Seq(modelDirEntity))

    FileUtils.deleteDirectory(new File("tmp"))
  }
} 
Example 79
Source File: ProxyFeedback.scala    From oni-ml   with Apache License 2.0 5 votes vote down vote up
package org.opennetworkinsight.proxy

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.types.{StructType, StructField, StringType}
import scala.io.Source

import org.opennetworkinsight.proxy.ProxySchema._


object ProxyFeedback {

  
  def loadFeedbackDF(sc: SparkContext,
                     sqlContext: SQLContext,
                     feedbackFile: String,
                     duplicationFactor: Int): DataFrame = {


    val feedbackSchema = StructType(
      List(StructField(Date, StringType, nullable= true),
        StructField(Time, StringType, nullable= true),
        StructField(ClientIP, StringType, nullable= true),
        StructField(Host, StringType, nullable= true),
        StructField(ReqMethod, StringType, nullable= true),
        StructField(UserAgent, StringType, nullable= true),
        StructField(ResponseContentType, StringType, nullable= true),
        StructField(RespCode, StringType, nullable= true),
        StructField(FullURI, StringType, nullable= true)))

    if (new java.io.File(feedbackFile).exists) {

      val dateIndex = 0
      val timeIndex = 1
      val clientIpIndex = 2
      val hostIndex = 3
      val reqMethodIndex = 4
      val userAgentIndex = 5
      val resContTypeIndex = 6
      val respCodeIndex = 11
      val fullURIIndex = 18

      val fullURISeverityIndex = 22

      val lines = Source.fromFile(feedbackFile).getLines().toArray.drop(1)
      val feedback: RDD[String] = sc.parallelize(lines)

      sqlContext.createDataFrame(feedback.map(_.split("\t"))
        .filter(row => row(fullURISeverityIndex).trim.toInt == 3)
        .map(row => Row.fromSeq(List(row(dateIndex),
          row(timeIndex),
          row(clientIpIndex),
          row(hostIndex),
          row(reqMethodIndex),
          row(userAgentIndex),
          row(resContTypeIndex),
          row(respCodeIndex),
          row(fullURIIndex))))
        .flatMap(row => List.fill(duplicationFactor)(row)), feedbackSchema)
        .select(Date, Time, ClientIP, Host, ReqMethod, UserAgent, ResponseContentType, RespCode, FullURI)
    } else {
      sqlContext.createDataFrame(sc.emptyRDD[Row], feedbackSchema)
    }
  }
} 
Example 80
Source File: StreamingQueryListenerSampleJob.scala    From spark-monitoring   with MIT License 5 votes vote down vote up
package com.microsoft.pnp.samplejob

import com.microsoft.pnp.logging.Log4jConfiguration
import com.microsoft.pnp.util.TryWith
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.UserMetricsSystems
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.window
import org.apache.spark.sql.types.{StringType, StructType, TimestampType}

object StreamingQueryListenerSampleJob extends Logging {

  private final val METRICS_NAMESPACE = "streamingquerylistenersamplejob"
  private final val COUNTER_NAME = "rowcounter"

  def main(args: Array[String]): Unit = {

    // Configure our logging
    TryWith(getClass.getResourceAsStream("/com/microsoft/pnp/samplejob/log4j.properties")) {
      stream => {
        Log4jConfiguration.configure(stream)
      }
    }

    logTrace("Trace message from StreamingQueryListenerSampleJob")
    logDebug("Debug message from StreamingQueryListenerSampleJob")
    logInfo("Info message from StreamingQueryListenerSampleJob")
    logWarning("Warning message from StreamingQueryListenerSampleJob")
    logError("Error message from StreamingQueryListenerSampleJob")

    val spark = SparkSession
      .builder
      .getOrCreate

    import spark.implicits._

    // This path contains sample files provided by Databricks for trying things out
    val inputPath = "/databricks-datasets/structured-streaming/events/"

    val jsonSchema = new StructType().add("time", TimestampType).add("action", StringType)

    val driverMetricsSystem = UserMetricsSystems
        .getMetricSystem(METRICS_NAMESPACE, builder => {
          builder.registerCounter(COUNTER_NAME)
        })

    driverMetricsSystem.counter(COUNTER_NAME).inc

    // Create a streaming DataFrame by using `readStream` instead of `read`
    val streamingInputDF =
      spark
        .readStream // `readStream` instead of `read` for creating streaming DataFrame
        .schema(jsonSchema) // Set the schema of the JSON data
        .option("maxFilesPerTrigger", 1) // Treat a sequence of files as a stream by picking one file at a time
        .json(inputPath)

    driverMetricsSystem.counter(COUNTER_NAME).inc(5)

    val streamingCountsDF =
      streamingInputDF
        .groupBy($"action", window($"time", "1 hour"))
        .count()

    // isStreaming returns true because this DataFrame was built from a streaming source
    // (the returned value is simply discarded in this sample)
    streamingCountsDF.isStreaming

    driverMetricsSystem.counter(COUNTER_NAME).inc(10)

    val query =
      streamingCountsDF
        .writeStream
        .format("memory") // memory = store in-memory table (for testing only in Spark 2.0)
        .queryName("counts") // counts = name of the in-memory table
        .outputMode("complete") // complete = all the counts should be in the table
        .start()
  }
} 
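The job above starts the memory-sink query but never waits on it, so a standalone run can exit before any batch completes. A minimal sketch (not part of the original sample) of how the in-memory "counts" table could be inspected, reusing the spark and query values defined in main:

    // Block until all input files currently visible have been processed,
    // then read back the in-memory sink registered under the name "counts".
    query.processAllAvailable()
    spark.table("counts").orderBy("action").show(truncate = false)
    query.stop()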
Example 81
Source File: UserData.scala    From Machine-Learning-with-Spark-Second-Edition   with MIT License 5 votes vote down vote up
package org.sparksamples.df
//import org.apache.spark.sql.SQLContext
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
package object UserData {
  def main(args: Array[String]): Unit = {
    val customSchema = StructType(Array(
      StructField("no", IntegerType, true),
      StructField("age", StringType, true),
      StructField("gender", StringType, true),
      StructField("occupation", StringType, true),
      StructField("zipCode", StringType, true)));
    val spConfig = (new SparkConf).setMaster("local").setAppName("SparkApp")
    val spark = SparkSession
      .builder()
      .appName("SparkUserData").config(spConfig)
      .getOrCreate()

    val user_df = spark.read.format("com.databricks.spark.csv")
      .option("delimiter", "|").schema(customSchema)
      .load("/home/ubuntu/work/ml-resources/spark-ml/data/ml-100k/u.user")
    val first = user_df.first()
    println("First Record : " + first)

    val num_genders = user_df.groupBy("gender").count().count()
    val num_occupations = user_df.groupBy("occupation").count().count()
    val num_zipcodes = user_df.groupBy("zipCode").count().count()

    println("num_users : " + user_df.count())
    println("num_genders : "+ num_genders)
    println("num_occupations : "+ num_occupations)
    println("num_zipcodes: " + num_zipcodes)
    println("Distribution by Occupation")
    user_df.groupBy("occupation").count().show()

  }
} 
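On Spark 2.x the external com.databricks.spark.csv package is not strictly needed; a sketch of the equivalent read with the built-in CSV source, reusing customSchema and the same (book-specific) file path:

    val userDfBuiltin = spark.read
      .option("delimiter", "|")
      .schema(customSchema)
      .csv("/home/ubuntu/work/ml-resources/spark-ml/data/ml-100k/u.user")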
Example 82
Source File: Util.scala    From Machine-Learning-with-Spark-Second-Edition   with MIT License 5 votes vote down vote up
package org.sparksamples

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql._
import org.apache.spark.sql.types.{StringType, StructField, StructType}


object Util {
  val PATH = "/home/ubuntu/work/spark-2.0.0-bin-hadoop2.7/"
  val DATA_PATH= "../../../data/ml-100k"
  val PATH_MOVIES = DATA_PATH + "/u.item"

  def reduceDimension2(x: Vector) : String= {
    var i = 0
    var l = x.toArray.size
    var l_2 = l/2.toInt
    var x_ = 0.0
    var y_ = 0.0

    for(i <- 0 until l_2) {
      x_ += x(i).toDouble
    }
    for(i <- (l_2 + 1) until l) {
      y_ += x(i).toDouble
    }
    var t = x_ + "," + y_
    return t
  }

  def getMovieDataDF(spark : SparkSession) : DataFrame = {

    //1|Toy Story (1995)|01-Jan-1995||http://us.imdb.com/M/title-exact?Toy%20Story%20(1995)
    // |0|0|0|1|1|1|0|0|0|0|0|0|0|0|0|0|0|0|0
    val customSchema = StructType(Array(
      StructField("id", StringType, true),
      StructField("name", StringType, true),
      StructField("date", StringType, true),
      StructField("url", StringType, true)));
    val movieDf = spark.read.format("com.databricks.spark.csv")
      .option("delimiter", "|").schema(customSchema)
      .load(PATH_MOVIES)
    return movieDf
  }

} 
Example 83
Source File: PrettifyTest.scala    From spark-testing-base   with Apache License 2.0 5 votes vote down vote up
package com.holdenkarau.spark.testing

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalacheck.Gen
import org.scalacheck.Prop._
import org.scalacheck.util.Pretty
import org.scalatest.FunSuite
import org.scalatest.exceptions.GeneratorDrivenPropertyCheckFailedException
import org.scalatest.prop.Checkers

class PrettifyTest extends FunSuite with SharedSparkContext with Checkers with Prettify {
  implicit val propertyCheckConfig = PropertyCheckConfig(minSize = 2, maxSize = 2)

  test("pretty output of DataFrame's check") {
    val schema = StructType(List(StructField("name", StringType), StructField("age", IntegerType)))
    val sqlContext = new SQLContext(sc)
    val nameGenerator = new Column("name", Gen.const("Holden Hanafy"))
    val ageGenerator = new Column("age", Gen.const(20))

    val dataframeGen = DataframeGenerator.arbitraryDataFrameWithCustomFields(sqlContext, schema)(nameGenerator, ageGenerator)

    val actual = runFailingCheck(dataframeGen.arbitrary)
    val expected =
      Some("arg0 = <DataFrame: schema = [name: string, age: int], size = 2, values = ([Holden Hanafy,20], [Holden Hanafy,20])>")
    assert(actual == expected)
  }

  test("pretty output of RDD's check") {
    val rddGen = RDDGenerator.genRDD[(String, Int)](sc) {
      for {
        name <- Gen.const("Holden Hanafy")
        age <- Gen.const(20)
      } yield name -> age
    }

    val actual = runFailingCheck(rddGen)
    val expected =
      Some("""arg0 = <RDD: size = 2, values = ((Holden Hanafy,20), (Holden Hanafy,20))>""")
    assert(actual == expected)
  }

  test("pretty output of Dataset's check") {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val datasetGen = DatasetGenerator.genDataset[(String, Int)](sqlContext) {
      for {
        name <- Gen.const("Holden Hanafy")
        age <- Gen.const(20)
      } yield name -> age
    }

    val actual = runFailingCheck(datasetGen)
    val expected =
      Some("""arg0 = <Dataset: schema = [_1: string, _2: int], size = 2, values = ((Holden Hanafy,20), (Holden Hanafy,20))>""")
    assert(actual == expected)
  }

  private def runFailingCheck[T](genUnderTest: Gen[T])(implicit p: T => Pretty) = {
    val property = forAll(genUnderTest)(_ => false)
    val e = intercept[GeneratorDrivenPropertyCheckFailedException] {
      check(property)
    }
    takeSecondToLastLine(e.message)
  }

  private def takeSecondToLastLine(msg: Option[String]) =
    msg.flatMap(_.split("\n").toList.reverse.tail.headOption.map(_.trim))

} 
Example 84
Source File: PartitionAndSleepWorkload.scala    From spark-bench   with Apache License 2.0 5 votes vote down vote up
package com.ibm.sparktc.sparkbench.workload.exercise

import com.ibm.sparktc.sparkbench.workload.{Workload, WorkloadDefaults}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import com.ibm.sparktc.sparkbench.utils.GeneralFunctions._
import com.ibm.sparktc.sparkbench.utils.SaveModes
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

object PartitionAndSleepWorkload extends WorkloadDefaults {
  val name = "timedsleep"
  val partitions: Int = 48
  val sleepms: Long = 12000L

  def apply(m: Map[String, Any]) = new PartitionAndSleepWorkload(
    input = None,
    output = None,
    partitions = getOrDefault[Int](m, "partitions", partitions),
    sleepMS = getOrDefault[Long](m, "sleepms", sleepms, any2Long))
}

case class PartitionAndSleepWorkload(input: Option[String] = None,
                                     output: Option[String] = None,
                                     saveMode: String = SaveModes.error,
                                     partitions: Int,
                                     sleepMS: Long) extends Workload {

  def doStuff(spark: SparkSession): (Long, Unit) = time {

    val ms = sleepMS
    val stuff: RDD[Int] = spark.sparkContext.parallelize(0 until partitions * 100, partitions)

    val cool: RDD[(Int, Int)] = stuff.map { i =>
      Thread.sleep(ms)
      (i % 10, i + 42)
    }

    val yeah = cool.reduceByKey(_ + _)
    yeah.collect()
  }

  override def doWorkload(df: Option[DataFrame] = None, spark: SparkSession): DataFrame = {
    val (t, _) = doStuff(spark)

    val schema = StructType(
      List(
        StructField("name", StringType, nullable = false),
        StructField("timestamp", LongType, nullable = false),
        StructField("runtime", LongType, nullable = false)
      )
    )

    val timeList = spark.sparkContext.parallelize(Seq(Row("timedsleep", System.currentTimeMillis(), t)))

    spark.createDataFrame(timeList, schema)
  }
} 
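The time helper used above comes from spark-bench's GeneralFunctions and its body is not shown in this listing. A self-contained stand-in with the shape the call sites imply (an assumption, not the project's actual code):

    // Hypothetical timing helper: runs the block and returns (elapsed milliseconds, result).
    def time[R](block: => R): (Long, R) = {
      val start = System.currentTimeMillis()
      val result = block
      (System.currentTimeMillis() - start, result)
    }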
Example 85
Source File: GraphDataGen.scala    From spark-bench   with Apache License 2.0 5 votes vote down vote up
package com.ibm.sparktc.sparkbench.datageneration

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import com.ibm.sparktc.sparkbench.utils.{SaveModes, SparkBenchException}
import com.ibm.sparktc.sparkbench.utils.GeneralFunctions.{any2Long, getOrDefault, getOrThrow, time}
import com.ibm.sparktc.sparkbench.workload.{Workload, WorkloadDefaults}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.apache.spark.graphx.util.GraphGenerators

object GraphDataGen extends WorkloadDefaults {

  val name = "graph-data-generator"
  val defaultMu = 4.0
  val defaultSigma = 1.3
  val defaultSeed = -1L
  val defaultNumOfPartitions = 0

  override def apply(m: Map[String, Any]): GraphDataGen = {
      val numVertices = getOrThrow(m, "vertices").asInstanceOf[Int]
      val mu = getOrDefault[Double](m, "mu", defaultMu)
      val sigma = getOrDefault[Double](m, "sigma", defaultSigma)
      val numPartitions = getOrDefault[Int](m, "partitions", defaultNumOfPartitions)
      val seed = getOrDefault[Long](m, "seed", defaultSeed, any2Long)
      val output = {
        val str = getOrThrow(m, "output").asInstanceOf[String]
        val s = verifySuitabilityOfOutputFileFormat(str)
        Some(s)
      }
    val saveMode = getOrDefault[String](m, "save-mode", SaveModes.error)

    new GraphDataGen(
      numVertices = numVertices,
      input = None,
      output = output,
      saveMode = saveMode,
      mu = mu,
      sigma = sigma,
      seed = seed,
      numPartitions = numPartitions
    )
  }

  
  private[datageneration] def verifySuitabilityOfOutputFileFormat(str: String): String = {
    val strArr: Array[String] = str.split('.')

    (strArr.length, strArr.last) match {
      case (1, _) => throw SparkBenchException("Output file for GraphDataGen must have \".txt\" as the file extension." +
        "Please modify your config file.")
      case (2, "txt") => str
      case (_, _) => throw SparkBenchException("Due to limitations of the GraphX GraphLoader, " +
        "the graph data generators may only save files as \".txt\"." +
        "Please modify your config file.")
    }
  }

}

case class GraphDataGen (
                          numVertices: Int,
                          input: Option[String] = None,
                          output: Option[String],
                          saveMode: String,
                          mu: Double = 4.0,
                          sigma: Double = 1.3,
                          seed: Long = 1,
                          numPartitions: Int = 0
                        ) extends Workload {

  override def doWorkload(df: Option[DataFrame] = None, spark: SparkSession): DataFrame = {
    val timestamp = System.currentTimeMillis()
    val (generateTime, graph) = time(GraphGenerators.logNormalGraph(spark.sparkContext, numVertices, numPartitions, mu, sigma))
    val (convertTime, out) = time(graph.edges.map(e => s"${e.srcId.toString} ${e.dstId}"))
    val (saveTime, _) = time(out.saveAsTextFile(output.get))

    val timeResultSchema = StructType(
      List(
        StructField("name", StringType, nullable = false),
        StructField("timestamp", LongType, nullable = false),
        StructField("generate", LongType, nullable = true),
        StructField("convert", LongType, nullable = true),
        StructField("save", LongType, nullable = true),
        StructField("total_runtime", LongType, nullable = false)
      )
    )
    val total = generateTime + convertTime + saveTime
    val timeList = spark.sparkContext.parallelize(Seq(Row(GraphDataGen.name, timestamp, generateTime, convertTime, saveTime, total)))
    spark.createDataFrame(timeList, timeResultSchema)
  }
} 
Example 86
Source File: LinearRegressionDataGen.scala    From spark-bench   with Apache License 2.0 5 votes vote down vote up
package com.ibm.sparktc.sparkbench.datageneration.mlgenerator

import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import com.ibm.sparktc.sparkbench.utils.{SaveModes, SparkBenchException}
import com.ibm.sparktc.sparkbench.utils.GeneralFunctions.{getOrDefault, getOrThrow, time}
import com.ibm.sparktc.sparkbench.utils.SparkFuncs.writeToDisk
import com.ibm.sparktc.sparkbench.workload.{Workload, WorkloadDefaults}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

object LinearRegressionDataGen extends WorkloadDefaults {
  val name = "data-generation-lr"
  // Application parameters: 1 million points correspond to roughly 200 MB of data
  val numOfExamples: Int = 40000
  val numOfFeatures: Int = 4
  val eps: Double = 0.5
  val intercepts: Double = 0.1
  val numOfPartitions: Int = 10
  val maxIteration: Int = 3
  override def apply(m: Map[String, Any]) = new LinearRegressionDataGen(
    numRows = getOrThrow(m, "rows").asInstanceOf[Int],
    numCols = getOrThrow(m, "cols").asInstanceOf[Int],
    output = Some(getOrThrow(m, "output").asInstanceOf[String]),
    saveMode = getOrDefault[String](m, "save-mode", SaveModes.error),
    eps = getOrDefault[Double](m, "eps", eps),
    intercepts = getOrDefault[Double](m, "intercepts", intercepts),
    numPartitions = getOrDefault[Int](m, "partitions", numOfPartitions)
  )
}

case class LinearRegressionDataGen (
                                      numRows: Int,
                                      numCols: Int,
                                      input: Option[String] = None,
                                      output: Option[String],
                                      saveMode: String,
                                      eps: Double,
                                      intercepts: Double,
                                      numPartitions: Int
                                   ) extends Workload {

  override def doWorkload(df: Option[DataFrame] = None, spark: SparkSession): DataFrame = {

    val timestamp = System.currentTimeMillis()

    val (generateTime, data): (Long, RDD[LabeledPoint]) = time {
      LinearDataGenerator.generateLinearRDD(
        spark.sparkContext,
        numRows,
        numCols,
        eps,
        numPartitions,
        intercepts
      )
    }

    import spark.implicits._
    val (convertTime, dataDF) = time {
      data.toDF
    }

    val (saveTime, _) = time {
      val outputstr = output.get
      if(outputstr.endsWith(".csv")) throw SparkBenchException("LabeledPoints cannot be saved to CSV. Please try outputting to Parquet instead.")
      writeToDisk(output.get, saveMode, dataDF, spark)
    } // TODO: LabeledPoint data cannot be written as CSV; Parquet works fine

    val timeResultSchema = StructType(
      List(
        StructField("name", StringType, nullable = false),
        StructField("timestamp", LongType, nullable = false),
        StructField("generate", LongType, nullable = true),
        StructField("convert", LongType, nullable = true),
        StructField("save", LongType, nullable = true),
        StructField("total_runtime", LongType, nullable = false)
      )
    )

    val total = generateTime + convertTime + saveTime

    val timeList = spark.sparkContext.parallelize(Seq(Row("kmeans", timestamp, generateTime, convertTime, saveTime, total)))

    spark.createDataFrame(timeList, timeResultSchema)

  }
} 
Example 87
Source File: NGram.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}


  @Since("1.5.0")
  def getN: Int = $(n)

  setDefault(n -> 2)

  override protected def createTransformFunc: Seq[String] => Seq[String] = {
    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType.sameType(ArrayType(StringType)),
      s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, false)
}

@Since("1.6.0")
object NGram extends DefaultParamsReadable[NGram] {

  @Since("1.6.0")
  override def load(path: String): NGram = super.load(path)
} 
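This listing omits the NGram class declaration itself; only a few members and the companion object survive. A short usage sketch of the transformer as exposed by spark.ml, assuming an active SparkSession named spark:

    import org.apache.spark.ml.feature.NGram
    import spark.implicits._

    val wordsDF = Seq((0, Seq("hello", "spark", "sql", "types"))).toDF("id", "words")
    val bigrams = new NGram().setN(2).setInputCol("words").setOutputCol("bigrams")
    bigrams.transform(wordsDF).select("bigrams").show(truncate = false)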
Example 88
Source File: MapDataSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }
} 
Example 89
Source File: ScalaUDFSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase,
      StringType,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

} 
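ScalaUDF is the internal Catalyst expression; at the DataFrame API level the same String => String function would normally be wrapped with functions.udf. A minimal sketch, assuming an active SparkSession named spark:

    import org.apache.spark.sql.functions.{col, udf}
    import spark.implicits._

    // Appends "x" to each value, guarding against null so the UDF does not throw an NPE.
    val appendX = udf((s: String) => if (s == null) null else s + "x")
    Seq("a", "b").toDF("value").select(appendX(col("value")).as("appended")).show()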
Example 90
Source File: RewriteDistinctAggregatesSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        CollectSet('b).toAggregateExpression().as('agg2))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 91
Source File: resources.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
Example 92
Source File: WholeStageCodegenSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined)
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }
} 
Example 93
Source File: GroupedIteratorSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 94
Source File: DDLSourceLoadSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// please note that the META-INF/services had to be modified for the test directory for this to work
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {
    intercept[RuntimeException] {
      spark.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {
    spark.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("specify full classname with duplicate formats") {
    spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("should fail to load ORC without Hive Support") {
    val e = intercept[AnalysisException] {
      spark.read.format("orc").load()
    }
    assert(e.message.contains("The ORC data source must be used with Hive support enabled"))
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider  with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
Example 95
Source File: ForecastPipelineStage.scala    From uberdata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml

import eleflow.uberdata.IUberdataForecastUtil
import org.apache.spark.ml.param.shared.{HasNFutures, HasPredictionCol, HasValidationCol}
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types.{StructType, StringType, StructField, MapType}


trait ForecastPipelineStage
    extends PipelineStage
    with HasNFutures
    with HasPredictionCol
    with HasValidationCol {

  def setValidationCol(value: String): this.type = set(validationCol, value)

  override def transformSchema(schema: StructType): StructType = {
    schema
      .add(StructField($(validationCol), new VectorUDT))
      .add(StructField(IUberdataForecastUtil.ALGORITHM, StringType))
      .add(StructField(IUberdataForecastUtil.PARAMS, MapType(StringType, StringType)))
  }
} 
Example 96
Source File: CarbonCatalystOperators.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import scala.collection.mutable

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.StringType

abstract class CarbonProfile(attributes: Seq[Attribute]) extends Serializable {
  def isEmpty: Boolean = attributes.isEmpty
}

case class IncludeProfile(attributes: Seq[Attribute]) extends CarbonProfile(attributes)

case class ExcludeProfile(attributes: Seq[Attribute]) extends CarbonProfile(attributes)

case class ProjectForUpdate(
    table: UnresolvedRelation,
    columns: List[String],
    children: Seq[LogicalPlan]) extends LogicalPlan {
  override def output: Seq[Attribute] = Seq.empty
}

case class UpdateTable(
    table: UnresolvedRelation,
    columns: List[String],
    selectStmt: String,
    alias: Option[String] = None,
    filer: String) extends LogicalPlan {
  override def children: Seq[LogicalPlan] = Seq.empty
  override def output: Seq[Attribute] = Seq.empty
}

case class DeleteRecords(
    statement: String,
    alias: Option[String] = None,
    table: UnresolvedRelation) extends LogicalPlan {
  override def children: Seq[LogicalPlan] = Seq.empty
  override def output: Seq[AttributeReference] = Seq.empty
}


  def strictCountStar(groupingExpressions: Seq[Expression],
      partialComputation: Seq[NamedExpression],
      child: LogicalPlan): Boolean = {
    if (groupingExpressions.nonEmpty) {
      return false
    }
    if (partialComputation.isEmpty) {
      return false
    }
    if (partialComputation.size > 1 && partialComputation.nonEmpty) {
      return false
    }
    child collect {
      case cd: Filter => return false
    }
    true
  }
} 
Example 97
Source File: CarbonShowStreamsCommand.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command.stream

import java.util.Date
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.{CarbonEnv, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.command.MetadataCommand
import org.apache.spark.sql.types.StringType

import org.apache.carbondata.stream.StreamJobManager


case class CarbonShowStreamsCommand(
    tableOp: Option[TableIdentifier]
) extends MetadataCommand {
  override def output: Seq[Attribute] = {
    Seq(AttributeReference("Stream Name", StringType, nullable = false)(),
      AttributeReference("JobId", StringType, nullable = false)(),
      AttributeReference("Status", StringType, nullable = false)(),
      AttributeReference("Source", StringType, nullable = false)(),
      AttributeReference("Sink", StringType, nullable = false)(),
      AttributeReference("Start Time", StringType, nullable = false)(),
      AttributeReference("Time Elapse", StringType, nullable = false)())
  }

  override def processMetadata(sparkSession: SparkSession): Seq[Row] = {
    val jobs = tableOp match {
      case None => StreamJobManager.getAllJobs.toSeq
      case Some(table) =>
        val carbonTable = CarbonEnv.getCarbonTable(table.database, table.table)(sparkSession)
        setAuditTable(carbonTable)
        StreamJobManager.getAllJobs.filter { job =>
          job.sinkTable.equalsIgnoreCase(carbonTable.getTableName) &&
          job.sinkDb.equalsIgnoreCase(carbonTable.getDatabaseName)
        }.toSeq
    }

    jobs.map { job =>
      val elapsedTime = System.currentTimeMillis() - job.startTime
      Row(
        job.streamName,
        job.streamingQuery.id.toString,
        if (job.streamingQuery.isActive) "RUNNING" else "FAILED",
        s"${ job.sourceDb }.${ job.sourceTable }",
        s"${ job.sinkDb }.${ job.sinkTable }",
        new Date(job.startTime).toString,
        String.format(
          "%s days, %s hours, %s min, %s sec",
          TimeUnit.MILLISECONDS.toDays(elapsedTime).toString,
          TimeUnit.MILLISECONDS.toHours(elapsedTime).toString,
          TimeUnit.MILLISECONDS.toMinutes(elapsedTime).toString,
          TimeUnit.MILLISECONDS.toSeconds(elapsedTime).toString)
      )
    }
  }

  override protected def opName: String = "SHOW STREAMS"
} 
Example 98
Source File: CarbonShowMVCommand.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command.view

import java.util

import scala.collection.JavaConverters._

import org.apache.spark.sql.{CarbonEnv, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.command.{Checker, DataCommand}
import org.apache.spark.sql.types.{BooleanType, StringType}

import org.apache.carbondata.core.view.{MVProperty, MVSchema}
import org.apache.carbondata.view.MVManagerInSpark


case class CarbonShowMVCommand(
    databaseNameOption: Option[String],
    relatedTableIdentifier: Option[TableIdentifier]) extends DataCommand {

  override def output: Seq[Attribute] = {
    Seq(
      AttributeReference("Database", StringType, nullable = false)(),
      AttributeReference("Name", StringType, nullable = false)(),
      AttributeReference("Status", StringType, nullable = false)(),
      AttributeReference("Refresh Mode", StringType, nullable = false)(),
      AttributeReference("Refresh Trigger Mode", StringType, nullable = false)(),
      AttributeReference("Properties", StringType, nullable = false)())
  }

  override def processData(session: SparkSession): Seq[Row] = {
    // Get mv schemas.
    val schemaList = new util.ArrayList[MVSchema]()
    val viewManager = MVManagerInSpark.get(session)
    relatedTableIdentifier match {
      case Some(table) =>
        val relatedTable = CarbonEnv.getCarbonTable(table)(session)
        setAuditTable(relatedTable)
        Checker.validateTableExists(table.database, table.table, session)
        if (databaseNameOption.isDefined) {
          schemaList.addAll(viewManager.getSchemasOnTable(
            databaseNameOption.get,
            relatedTable))
        } else {
          schemaList.addAll(viewManager.getSchemasOnTable(relatedTable))
        }
      case _ =>
        if (databaseNameOption.isDefined) {
          schemaList.addAll(viewManager.getSchemas(databaseNameOption.get))
        } else {
          schemaList.addAll(viewManager.getSchemas())
        }
    }
    // Convert mv schema to row.
    schemaList.asScala.map {
      schema =>
        Row(
          schema.getIdentifier.getDatabaseName,
          schema.getIdentifier.getTableName,
          schema.getStatus.name(),
          schema.getProperties.get(MVProperty.REFRESH_MODE),
          schema.getProperties.get(MVProperty.REFRESH_TRIGGER_MODE),
          schema.getPropertiesAsString
        )
    }
  }

  override protected def opName: String = "SHOW MATERIALIZED VIEW"
} 
Example 99
Source File: CarbonCliCommand.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command.management

import java.util

import scala.collection.JavaConverters._

import org.apache.spark.sql.{CarbonEnv, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.command.{Checker, DataCommand}
import org.apache.spark.sql.types.StringType

import org.apache.carbondata.tool.CarbonCli


case class CarbonCliCommand(
    databaseNameOp: Option[String],
    tableName: String,
    commandOptions: String)
  extends DataCommand {

  override def output: Seq[Attribute] = {
      Seq(AttributeReference("CarbonCli", StringType, nullable = false)())
  }

  override def processData(sparkSession: SparkSession): Seq[Row] = {
    Checker.validateTableExists(databaseNameOp, tableName, sparkSession)
    val carbonTable = CarbonEnv.getCarbonTable(databaseNameOp, tableName)(sparkSession)
    setAuditTable(carbonTable)
    setAuditInfo(Map("options" -> commandOptions))
    val commandArgs: Seq[String] = commandOptions.split("\\s+").map(_.trim)
    val finalCommands = commandArgs.exists(_.equalsIgnoreCase("-p")) match {
      case true =>
        commandArgs
      case false =>
        val needPath = commandArgs.exists { command =>
          command.equalsIgnoreCase("summary") || command.equalsIgnoreCase("benchmark")
        }
        needPath match {
          case true =>
            commandArgs ++ Seq("-p", carbonTable.getTablePath)
          case false =>
            commandArgs
        }
    }
    val summaryOutput = new util.ArrayList[String]()
    CarbonCli.run(finalCommands.toArray, summaryOutput, false)
    summaryOutput.asScala.map(x =>
      Row(x)
    )
  }

  override protected def opName: String = "CLI"
} 
Example 100
Source File: CarbonShowTablesCommand.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command.table

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.command.MetadataCommand
import org.apache.spark.sql.types.{BooleanType, StringType}


private[sql] case class CarbonShowTablesCommand ( databaseName: Option[String],
    tableIdentifierPattern: Option[String])  extends MetadataCommand{

  // The result of SHOW TABLES has three columns: database, tableName and isTemporary.
  override val output: Seq[Attribute] = {
    AttributeReference("database", StringType, nullable = false)() ::
    AttributeReference("tableName", StringType, nullable = false)() ::
    AttributeReference("isTemporary", BooleanType, nullable = false)() :: Nil
  }

  override def processMetadata(sparkSession: SparkSession): Seq[Row] = {
    // Since we need to return a Seq of rows, we will call getTables directly
    // instead of calling tables in sparkSession.
    val catalog = sparkSession.sessionState.catalog
    val db = databaseName.getOrElse(catalog.getCurrentDatabase)
    val tables =
      tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db))
    val externalCatalog = sparkSession.sharedState.externalCatalog
    // this method checks whether the table is mainTable or MV based on property "isVisible"
    def isMainTable(tableIdent: TableIdentifier) = {
      var isMainTable = true
      try {
        isMainTable = externalCatalog.getTable(db, tableIdent.table).storage.properties
          .getOrElse("isVisible", true).toString.toBoolean
      } catch {
        case ex: Throwable =>
        // ignore the exception for show tables
      }
      isMainTable
    }
    // tables will be filtered for all the MVs to show only main tables
    tables.collect {
      case tableIdent if isMainTable(tableIdent) =>
        val isTemp = catalog.isTemporaryTable(tableIdent)
        Row(tableIdent.database.getOrElse("default"), tableIdent.table, isTemp)
    }

  }

  override protected def opName: String = "SHOW TABLES"
} 
Example 101
Source File: CarbonExplainCommand.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command.table

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, Union}
import org.apache.spark.sql.execution.command.{ExplainCommand, MetadataCommand}
import org.apache.spark.sql.types.StringType

import org.apache.carbondata.core.profiler.ExplainCollector

case class CarbonExplainCommand(
    child: LogicalPlan,
    override val output: Seq[Attribute] =
    Seq(AttributeReference("plan", StringType, nullable = true)()))
  extends MetadataCommand {

  override def processMetadata(sparkSession: SparkSession): Seq[Row] = {
    val explainCommand = child.asInstanceOf[ExplainCommand]
    setAuditInfo(Map("query" -> explainCommand.logicalPlan.simpleString))
    val isCommand = explainCommand.logicalPlan match {
      case _: Command => true
      case Union(children) if children.forall(_.isInstanceOf[Command]) => true
      case _ => false
    }

    if (explainCommand.logicalPlan.isStreaming || isCommand) {
      explainCommand.run(sparkSession)
    } else {
      CarbonExplainCommand.collectProfiler(explainCommand, sparkSession) ++
      explainCommand.run(sparkSession)
    }
  }

  override protected def opName: String = "EXPLAIN"
}

case class CarbonInternalExplainCommand(
    explainCommand: ExplainCommand,
    override val output: Seq[Attribute] =
    Seq(AttributeReference("plan", StringType, nullable = true)()))
  extends MetadataCommand {

  override def processMetadata(sparkSession: SparkSession): Seq[Row] = {
    CarbonExplainCommand
      .collectProfiler(explainCommand, sparkSession) ++ explainCommand.run(sparkSession)
  }

  override protected def opName: String = "Carbon EXPLAIN"
}

object CarbonExplainCommand {
  def collectProfiler(
      explain: ExplainCommand,
      sparkSession: SparkSession): Seq[Row] = {
    try {
      ExplainCollector.setup()
      if (ExplainCollector.enabled()) {
        val queryExecution =
          sparkSession.sessionState.executePlan(explain.logicalPlan)
        queryExecution.toRdd.partitions
        // For count(*) queries the explain collector is disabled, so profiler
        // information is not required in such scenarios.
        if (null == ExplainCollector.getFormatedOutput) {
          Seq.empty
        }
        Seq(Row("== CarbonData Profiler ==\n" + ExplainCollector.getFormatedOutput))
      } else {
        Seq.empty
      }
    } finally {
      ExplainCollector.remove()
    }
  }
} 
Example 102
Source File: CarbonUDFTransformRule.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.optimizer

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, PredicateHelper,
ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.types.StringType

import org.apache.carbondata.core.constants.CarbonCommonConstants

class CarbonUDFTransformRule extends Rule[LogicalPlan] with PredicateHelper {
  override def apply(plan: LogicalPlan): LogicalPlan = {
      pushDownUDFToJoinLeftRelation(plan)
  }

  private def pushDownUDFToJoinLeftRelation(plan: LogicalPlan): LogicalPlan = {
    val output = plan.transform {
      case proj@Project(cols, Join(
      left, right, jointype: org.apache.spark.sql.catalyst.plans.JoinType, condition)) =>
        var projectionToBeAdded: Seq[org.apache.spark.sql.catalyst.expressions.Alias] = Seq.empty
        var udfExists = false
        val newCols = cols.map {
          case a@Alias(s: ScalaUDF, name)
            if name.equalsIgnoreCase(CarbonCommonConstants.POSITION_ID) ||
               name.equalsIgnoreCase(CarbonCommonConstants.CARBON_IMPLICIT_COLUMN_TUPLEID) =>
            udfExists = true
            projectionToBeAdded :+= a
            AttributeReference(name, StringType, nullable = true)().withExprId(a.exprId)
          case other => other
        }
        if (udfExists) {
          val newLeft = left match {
            case Project(columns, logicalPlan) =>
              Project(columns ++ projectionToBeAdded, logicalPlan)
            case filter: Filter =>
              Project(filter.output ++ projectionToBeAdded, filter)
            case relation: LogicalRelation =>
              Project(relation.output ++ projectionToBeAdded, relation)
            case other => other
          }
          Project(newCols, Join(newLeft, right, jointype, condition))
        } else {
          proj
        }
      case other => other
    }
    output
  }

} 
Example 103
Source File: CarbonDataFrameExample.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.examples

import org.apache.spark.sql.{SaveMode, SparkSession}

import org.apache.carbondata.examples.util.ExampleUtils

object CarbonDataFrameExample {

  def main(args: Array[String]) {
    val spark = ExampleUtils.createSparkSession("CarbonDataFrameExample")
    exampleBody(spark)
    spark.close()
  }

  def exampleBody(spark : SparkSession): Unit = {
    // Writes Dataframe to CarbonData file:
    import spark.implicits._
    val df = spark.sparkContext.parallelize(1 to 100)
      .map(x => ("a" + x % 10, "b", x))
      .toDF("c1", "c2", "number")

    // Saves dataframe to carbondata file
    df.write
      .format("carbondata")
      .option("tableName", "carbon_df_table")
      .option("partitionColumns", "c1")  // a list of column names
      .mode(SaveMode.Overwrite)
      .save()

    spark.sql(""" SELECT * FROM carbon_df_table """).show()

    spark.sql("SHOW PARTITIONS carbon_df_table").show()

    // Specify schema
    import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType}
    val customSchema = StructType(Array(
      StructField("c1", StringType),
      StructField("c2", StringType),
      StructField("number", IntegerType)))

    // Reads carbondata to dataframe
    val carbondf = spark.read
      .format("carbondata")
      .schema(customSchema)
      // .option("dbname", "db_name") the system will use "default" as dbname if not set this option
      .option("tableName", "carbon_df_table")
      .load()

    // Dataframe operations
    carbondf.printSchema()
    carbondf.select($"c1", $"number" + 10).show()
    carbondf.filter($"number" > 31).show()

    spark.sql("DROP TABLE IF EXISTS carbon_df_table")
  }
} 
Example 104
Source File: NestedTableExample.scala    From spark_training   with Apache License 2.0 5 votes vote down vote up
package com.malaska.spark.training.nested

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object NestedTableExample {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder
      .master("local")
      .appName("my-spark-app")
      .config("spark.some.config.option", "config-value")
      .config("spark.driver.host","127.0.0.1")
      .enableHiveSupport()
      .getOrCreate()


    spark.sql("create table IF NOT EXISTS nested_empty " +
      "( A int, " +
      "  B string, " +
      "  nested ARRAY<STRUCT< " +
      "     nested_C: int," +
      "     nested_D: string" +
      "  >>" +
      ") ")

    val rowRDD = spark.sparkContext.
      parallelize(Array(
        Row(1, "foo", Seq(Row(1, "barA"),Row(2, "bar"))),
        Row(2, "foo", Seq(Row(1, "barB"),Row(2, "bar"))),
        Row(3, "foo", Seq(Row(1, "barC"),Row(2, "bar")))))

    val emptyDf = spark.sql("select * from nested_empty limit 0")

    val tableSchema = emptyDf.schema

    val populated1Df = spark.sqlContext.createDataFrame(rowRDD, tableSchema)

    println("----")
    populated1Df.collect().foreach(r => println(" emptySchemaExample:" + r))

    val nestedSchema = new StructType()
      .add("nested_C", IntegerType)
      .add("nested_D", StringType)

    val definedSchema = new StructType()
      .add("A", IntegerType)
      .add("B", StringType)
      .add("nested", ArrayType(nestedSchema))

    val populated2Df = spark.sqlContext.createDataFrame(rowRDD, definedSchema)
    println("----")
    populated2Df.collect().foreach(r => println(" BuiltExample:" + r))

    spark.stop()
  }
} 
Example 105
Source File: PopulateHiveTable.scala    From spark_training   with Apache License 2.0 5 votes vote down vote up
package com.malaska.spark.training.nested

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}


object PopulateHiveTable {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder
      .master("local")
      .appName("my-spark-app")
      .config("spark.some.config.option", "config-value")
      .config("spark.driver.host","127.0.0.1")
      .config("spark.sql.parquet.compression.codec", "gzip")
      .enableHiveSupport()
      .getOrCreate()


    spark.sql("create table IF NOT EXISTS nested_empty " +
      "( A int, " +
      "  B string, " +
      "  nested ARRAY<STRUCT< " +
      "     nested_C: int," +
      "     nested_D: string" +
      "  >>" +
      ") ")

    val rowRDD = spark.sparkContext.
      parallelize(Array(
        Row(1, "foo", Seq(Row(1, "barA"),Row(2, "bar"))),
        Row(2, "foo", Seq(Row(1, "barB"),Row(2, "bar"))),
        Row(3, "foo", Seq(Row(1, "barC"),Row(2, "bar")))))

    val emptyDf = spark.sql("select * from nested_empty limit 0")

    val tableSchema = emptyDf.schema

    val populated1Df = spark.sqlContext.createDataFrame(rowRDD, tableSchema)

    populated1Df.repartition(2).write.saveAsTable("nested_populated")

    println("----")
    populated1Df.collect().foreach(r => println(" emptySchemaExample:" + r))

    val nestedSchema = new StructType()
      .add("nested_C", IntegerType)
      .add("nested_D", StringType)

    val definedSchema = new StructType()
      .add("A", IntegerType)
      .add("B", StringType)
      .add("nested", ArrayType(nestedSchema))

    val populated2Df = spark.sqlContext.createDataFrame(rowRDD, definedSchema)

    println("----")
    populated2Df.collect().foreach(r => println(" BuiltExample:" + r))

    spark.stop()
  }
} 
Example 106
Source File: OutputMetricsTest.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.types.{IntegerType, StringType}

class OutputMetricsTest extends IntegrationSuiteBase {
  it("records written") {
    var outputWritten = 0L
    spark.sparkContext.addSparkListener(new SparkListener() {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
        val metrics = taskEnd.taskMetrics
        outputWritten += metrics.outputMetrics.recordsWritten
      }
    })

    val numRows = 100000
    val df1 = spark.createDF(
      List.range(0, numRows),
      List(("id", IntegerType, true))
    )

    df1.repartition(30)
      .write
      .format("memsql")
      .save("metricsInts")

    assert(outputWritten == numRows)
    outputWritten = 0

    val df2 = spark.createDF(
      List("st1", "", null),
      List(("st", StringType, true))
    )

    df2.write
      .format("memsql")
      .save("metricsStrings")

    assert(outputWritten == 3)
  }
} 
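For readers who have not used spark-daria, a minimal sketch of what the second createDF call above expands to in plain Spark APIs; the column name and values mirror the test, everything else is illustrative.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val schema = StructType(Seq(StructField("st", StringType, nullable = true)))
val rows = spark.sparkContext.parallelize(Seq(Row("st1"), Row(""), Row(null)))
val df2Equivalent = spark.createDataFrame(rows, schema)
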
Example 107
Source File: SparkRecoverPartitionsCustomTest.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up
package com.adidas.analytics.unit

import com.adidas.analytics.util.SparkRecoverPartitionsCustom
import com.adidas.utils.SparkSessionWrapper
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, PrivateMethodTester}

import scala.collection.JavaConverters._

class SparkRecoverPartitionsCustomTest extends FunSuite
  with SparkSessionWrapper
  with PrivateMethodTester
  with Matchers
  with BeforeAndAfterAll {

  test("test conversion of String Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue("theValue")

    result should be("'theValue'")
  }

  test("test conversion of Short Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Short.valueOf("2"))

    result should be("2")
  }

  test("test conversion of Integer Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Integer.valueOf("4"))

    result should be("4")
  }

  test("test conversion of null Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(null)
    }
  }

  test("test conversion of not supported Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(false)
    }
  }

  test("test HiveQL statements Generation") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(
      tableName="test",
      targetPartitions = Seq("country","district")
    )

    val rowsInput = Seq(
      Row(1, "portugal", "porto"),
      Row(2, "germany", "herzogenaurach"),
      Row(3, "portugal", "coimbra")
    )

    val inputSchema = StructType(
      List(
        StructField("number", IntegerType, nullable = true),
        StructField("country", StringType, nullable = true),
        StructField("district", StringType, nullable = true)
      )
    )

    val expectedStatements: Seq[String] = Seq(
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='germany',district='herzogenaurach')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='coimbra')"
    )

    val testDataset: Dataset[Row] = spark.createDataset(rowsInput)(RowEncoder(inputSchema))

    val generateAddPartitionStatements = PrivateMethod[Dataset[String]]('generateAddPartitionStatements)

    val producedStatements: Seq[String] = (customSparkRecoverPartitions invokePrivate generateAddPartitionStatements(testDataset))
      .collectAsList()
      .asScala

    expectedStatements.sorted.toSet should equal(producedStatements.sorted.toSet)
  }

  override def afterAll(): Unit = {
    spark.stop()
  }

} 
Example 108
Source File: RecoverPartitionsCustomTest.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up
package com.adidas.analytics.unit

import com.adidas.analytics.util.RecoverPartitionsCustom
import com.adidas.utils.SparkSessionWrapper
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, PrivateMethodTester}

import scala.collection.JavaConverters._

class RecoverPartitionsCustomTest extends FunSuite
  with SparkSessionWrapper
  with PrivateMethodTester
  with Matchers
  with BeforeAndAfterAll {

  test("test conversion of String Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue("theValue")

    result should be("'theValue'")
  }

  test("test conversion of Short Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Short.valueOf("2"))

    result should be("2")
  }

  test("test conversion of Integer Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Integer.valueOf("4"))

    result should be("4")
  }

  test("test conversion of null Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(null)
    }
  }

  test("test conversion of not supported Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(false)
    }
  }

  test("test HiveQL statements Generation") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(
      tableName="test",
      targetPartitions = Seq("country","district")
    )

    val rowsInput = Seq(
      Row(1, "portugal", "porto"),
      Row(2, "germany", "herzogenaurach"),
      Row(3, "portugal", "coimbra")
    )

    val inputSchema = StructType(
      List(
        StructField("number", IntegerType, nullable = true),
        StructField("country", StringType, nullable = true),
        StructField("district", StringType, nullable = true)
      )
    )

    val expectedStatements: Seq[String] = Seq(
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='germany',district='herzogenaurach')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='coimbra')"
    )

    val testDataset: Dataset[Row] = spark.createDataset(rowsInput)(RowEncoder(inputSchema))

    val generateAddPartitionStatements = PrivateMethod[Dataset[String]]('generateAddPartitionStatements)

    val producedStatements: Seq[String] = (customSparkRecoverPartitions invokePrivate generateAddPartitionStatements(testDataset))
      .collectAsList()
      .asScala

    expectedStatements.sorted.toSet should equal(producedStatements.sorted.toSet)
  }

  override def afterAll(): Unit = {
    spark.stop()
  }

} 
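Both suites above reach private methods through ScalaTest's PrivateMethodTester; here is a minimal, self-contained sketch of that pattern, with a made-up Greeter class purely for illustration.

import org.scalatest.PrivateMethodTester

class Greeter { private def greet(name: String): String = s"hello $name" }

object PrivateCallSketch extends PrivateMethodTester {
  val greet = PrivateMethod[String]('greet)
  val result: String = new Greeter invokePrivate greet("world") // "hello world"
}
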
Example 109
Source File: TemporalDataSuite.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.datasource.config.ConfigParameters._
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.BeforeAndAfter

private[datasource] trait TemporalDataSuite extends DatasourceSuite
  with BeforeAndAfter {

  val conf = new SparkConf()
    .setAppName("datasource-receiver-example")
    .setIfMissing("spark.master", "local[*]")
  var sc: SparkContext = null
  var ssc: StreamingContext = null
  val tableName = "tableName"
  val datasourceParams = Map(
    StopGracefully -> "true",
    StopSparkContext -> "false",
    StorageLevelKey -> "MEMORY_ONLY",
    RememberDuration -> "15s"
  )
  val schema = new StructType(Array(
    StructField("id", StringType, nullable = true),
    StructField("idInt", IntegerType, nullable = true)
  ))
  val totalRegisters = 10000
  val registers = for (a <- 1 to totalRegisters) yield Row(a.toString, a)

  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
    if (sc != null) {
      sc.stop()
      sc = null
    }
  }
} 
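A minimal sketch (not part of the trait) of how the schema and registers above might be materialised in a concrete suite; building the SparkSession straight from conf here is illustrative.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder.config(conf).getOrCreate()
val df = spark.createDataFrame(spark.sparkContext.parallelize(registers), schema)
assert(df.count() == totalRegisters)
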
Example 110
Source File: SparkEsBulkWriterSpec.scala    From Spark2Elasticsearch   with Apache License 2.0 5 votes vote down vote up
package com.github.jparkie.spark.elasticsearch

import com.github.jparkie.spark.elasticsearch.conf.{ SparkEsMapperConf, SparkEsWriteConf }
import com.github.jparkie.spark.elasticsearch.sql.{ SparkEsDataFrameMapper, SparkEsDataFrameSerializer }
import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
import org.apache.spark.sql.{ Row, SQLContext }
import org.scalatest.{ MustMatchers, WordSpec }

class SparkEsBulkWriterSpec extends WordSpec with MustMatchers with SharedSparkContext {
  val esServer = new ElasticSearchServer()

  override def beforeAll(): Unit = {
    super.beforeAll()

    esServer.start()
  }

  override def afterAll(): Unit = {
    esServer.stop()

    super.afterAll()
  }

  "SparkEsBulkWriter" must {
    "execute write() successfully" in {
      esServer.createAndWaitForIndex("test_index")

      val sqlContext = new SQLContext(sc)

      val inputSparkEsWriteConf = SparkEsWriteConf(
        bulkActions = 10,
        bulkSizeInMB = 1,
        concurrentRequests = 0,
        flushTimeoutInSeconds = 1
      )
      val inputMapperConf = SparkEsMapperConf(
        esMappingId = Some("id"),
        esMappingParent = None,
        esMappingVersion = None,
        esMappingVersionType = None,
        esMappingRouting = None,
        esMappingTTLInMillis = None,
        esMappingTimestamp = None
      )
      val inputSchema = StructType(
        Array(
          StructField("id", StringType, true),
          StructField("parent", StringType, true),
          StructField("version", LongType, true),
          StructField("routing", StringType, true),
          StructField("ttl", LongType, true),
          StructField("timestamp", StringType, true),
          StructField("value", LongType, true)
        )
      )
      val inputData = sc.parallelize {
        Array(
          Row("TEST_ID_1", "TEST_PARENT_1", 1L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 1L),
          Row("TEST_ID_1", "TEST_PARENT_2", 2L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 2L),
          Row("TEST_ID_1", "TEST_PARENT_3", 3L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 3L),
          Row("TEST_ID_1", "TEST_PARENT_4", 4L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 4L),
          Row("TEST_ID_1", "TEST_PARENT_5", 5L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 5L),
          Row("TEST_ID_5", "TEST_PARENT_6", 6L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 6L),
          Row("TEST_ID_6", "TEST_PARENT_7", 7L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 7L),
          Row("TEST_ID_7", "TEST_PARENT_8", 8L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 8L),
          Row("TEST_ID_8", "TEST_PARENT_9", 9L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 9L),
          Row("TEST_ID_9", "TEST_PARENT_10", 10L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 10L),
          Row("TEST_ID_10", "TEST_PARENT_11", 11L, "TEST_ROUTING_1", 86400000L, "TEST_TIMESTAMP_1", 11L)
        )
      }
      val inputDataFrame = sqlContext.createDataFrame(inputData, inputSchema)
      val inputDataIterator = inputDataFrame.rdd.toLocalIterator
      val inputSparkEsBulkWriter = new SparkEsBulkWriter[Row](
        esIndex = "test_index",
        esType = "test_type",
        esClient = () => esServer.client,
        sparkEsSerializer = new SparkEsDataFrameSerializer(inputSchema),
        sparkEsMapper = new SparkEsDataFrameMapper(inputMapperConf),
        sparkEsWriteConf = inputSparkEsWriteConf
      )

      inputSparkEsBulkWriter.write(null, inputDataIterator)

      val outputGetResponse = esServer.client.prepareGet("test_index", "test_type", "TEST_ID_1").get()

      outputGetResponse.isExists mustEqual true
      outputGetResponse.getSource.get("parent").asInstanceOf[String] mustEqual "TEST_PARENT_5"
      outputGetResponse.getSource.get("version").asInstanceOf[Integer] mustEqual 5
      outputGetResponse.getSource.get("routing").asInstanceOf[String] mustEqual "TEST_ROUTING_1"
      outputGetResponse.getSource.get("ttl").asInstanceOf[Integer] mustEqual 86400000
      outputGetResponse.getSource.get("timestamp").asInstanceOf[String] mustEqual "TEST_TIMESTAMP_1"
      outputGetResponse.getSource.get("value").asInstanceOf[Integer] mustEqual 5
    }
  }
} 
Example 111
Source File: NullValuesTest.scala    From spark-dynamodb   with Apache License 2.0 5 votes vote down vote up
package com.audienceproject.spark.dynamodb

import com.amazonaws.services.dynamodbv2.model.{AttributeDefinition, CreateTableRequest, KeySchemaElement, ProvisionedThroughput}
import com.audienceproject.spark.dynamodb.implicits._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class NullValuesTest extends AbstractInMemoryTest {

    test("Insert nested StructType with null values") {
        dynamoDB.createTable(new CreateTableRequest()
            .withTableName("NullTest")
            .withAttributeDefinitions(new AttributeDefinition("name", "S"))
            .withKeySchema(new KeySchemaElement("name", "HASH"))
            .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L)))

        val schema = StructType(
            Seq(
                StructField("name", StringType, nullable = false),
                StructField("info", StructType(
                    Seq(
                        StructField("age", IntegerType, nullable = true),
                        StructField("address", StringType, nullable = true)
                    )
                ), nullable = true)
            )
        )

        val rows = spark.sparkContext.parallelize(Seq(
            Row("one", Row(30, "Somewhere")),
            Row("two", null),
            Row("three", Row(null, null))
        ))

        val newItemsDs = spark.createDataFrame(rows, schema)

        newItemsDs.write.dynamodb("NullTest")

        val validationDs = spark.read.dynamodb("NullTest")

        validationDs.show(false)
    }

} 
Example 112
Source File: BigQuerySource.scala    From spark-bigquery   with Apache License 2.0 5 votes vote down vote up
package com.samelamin.spark.bigquery.streaming

import java.math.BigInteger
import com.google.cloud.hadoop.io.bigquery.BigQueryStrings
import com.samelamin.spark.bigquery.BigQueryClient
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.streaming.{Offset, _}
import org.apache.spark.sql.types.{BinaryType, StringType, StructField, StructType}
import com.samelamin.spark.bigquery._
import com.samelamin.spark.bigquery.converters.SchemaConverters
import org.joda.time.DateTime
import org.slf4j.LoggerFactory


  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
    val startIndex = start.getOrElse(LongOffset(0L)).asInstanceOf[LongOffset].offset.toLong
    val endIndex = end.asInstanceOf[LongOffset].offset.toLong
    val startPartitionTime = new DateTime(startIndex).toLocalDate
    val endPartitionTime = new DateTime(endIndex).toLocalDate.toString
    logger.info(s"Fetching data between $startIndex and $endIndex")
    val query =
      s"""
         |SELECT
         |  *
         |FROM
         |  `${fullyQualifiedOutputTableId.replace(':','.')}`
         |WHERE
         |  $timestampColumn BETWEEN TIMESTAMP_MILLIS($startIndex) AND TIMESTAMP_MILLIS($endIndex)
         |  AND _PARTITIONTIME BETWEEN TIMESTAMP('$startPartitionTime') AND TIMESTAMP('$endPartitionTime')
         |  """.stripMargin
    val bigQuerySQLContext = new BigQuerySQLContext(sqlContext)
    val df = bigQuerySQLContext.bigQuerySelect(query)
    df
  }

  override def stop(): Unit = {}
  def getConvertedSchema(sqlContext: SQLContext): StructType = {
    val bigqueryClient = BigQueryClient.getInstance(sqlContext)
    val tableReference = BigQueryStrings.parseTableReference(fullyQualifiedOutputTableId)
    SchemaConverters.BQToSQLSchema(bigqueryClient.getTableSchema(tableReference))
  }
}

object BigQuerySource {
  val DEFAULT_SCHEMA = StructType(
    StructField("Sample Column", StringType) ::
      StructField("value", BinaryType) :: Nil
  )
} 
Example 113
Source File: Cleaner.scala    From CkoocNLP   with Apache License 2.0 5 votes vote down vote up
package functions.clean

import com.hankcs.hanlp.HanLP
import config.paramconf.{HasOutputCol, HasInputCol}
import functions.MySchemaUtils
import functions.clean.chinese.BCConvert
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{IntParam, Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}



  setDefault(fanjan -> "f2j", quanban -> "q2b", minLineLen -> 1)

  override def transform(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema, logging = true)

    val cleanFunc = udf {line: String =>
      var cleaned = ""
      getFanJian match {
        case "f2j" => cleaned = HanLP.convertToSimplifiedChinese(line)
        case "j2f" => cleaned = HanLP.convertToTraditionalChinese(line)
        case _ => cleaned = line
      }

      getQuanBan match {
        case "q2b" => cleaned = BCConvert.qj2bj(cleaned)
        case "b2q" => cleaned = BCConvert.bj2qj(cleaned)
        case _ => cleaned = cleaned
      }

      cleaned
    }

    val metadata = outputSchema($(outputCol)).metadata
    dataset.select(col("*"), cleanFunc(col($(inputCol))).as($(outputCol), metadata)).filter{record =>
      val outputIndex = record.fieldIndex($(outputCol))
      record.getString(outputIndex).length >= getMinLineLen
    }
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    require(inputType.typeName.equals(StringType.typeName),
      s"Input type must be StringType but got $inputType.")
    MySchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
  }
}


object Cleaner extends DefaultParamsReadable[Cleaner] {
  override def load(path: String): Cleaner = super.load(path)
} 
Example 114
Source File: ExtAggregatesSpec.scala    From spark-ext   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import com.collective.TestSparkContext
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.scalatest.FlatSpec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.ext.functions._

import scala.collection.mutable

class ExtAggregatesSpec extends FlatSpec with TestSparkContext {

  val schema = StructType(Seq(
    StructField("cookie_id", StringType),
    StructField("site", StringType),
    StructField("impressions", LongType)
  ))

  val cookie1 = "cookie1"
  val cookie2 = "cookie2"
  val cookie3 = "cookie3"

  val impressionLog = sqlContext.createDataFrame(sc.parallelize(Seq(
    Row(cookie1, "google.com", 10L),
    Row(cookie1, "cnn.com", 14L),
    Row(cookie1, "google.com", 2L),
    Row(cookie2, "bbc.com", 20L),
    Row(cookie2, "auto.com", null),
    Row(cookie2, "auto.com", 1L),
    Row(cookie3, "sport.com", 100L)
  )), schema)

  "Ext Aggregates" should "collect column values as array" in {
    val cookies = impressionLog
      .select(collectArray(col("cookie_id")))
      .first().getAs[mutable.WrappedArray[String]](0)
    assert(cookies.length == 7)
    assert(cookies.toSet.size == 3)
  }

  it should "collect distinct values as array" in {
    val distinctCookies = impressionLog.select(col("cookie_id"))
      .distinct()
      .select(collectArray(col("cookie_id")))
      .first().getAs[mutable.WrappedArray[String]](0)
    assert(distinctCookies.length == 3)
  }

  it should "collect values after group by" in {
    val result = impressionLog
      .groupBy(col("cookie_id"))
      .agg(collectArray(col("site")))

    val cookieSites = result.collect().map { case Row(cookie: String, sites: mutable.WrappedArray[_]) =>
      cookie -> sites.toSeq
    }.toMap

    assert(cookieSites(cookie1).length == 3)
    assert(cookieSites(cookie2).length == 3)
    assert(cookieSites(cookie3).length == 1)

  }

} 
Example 115
Source File: VacuumTableCommand.scala    From delta   with Apache License 2.0 5 votes vote down vote up
package io.delta.tables.execution

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaTableIdentifier, DeltaTableUtils}
import org.apache.spark.sql.delta.commands.VacuumCommand
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.types.StringType


case class VacuumTableCommand(
    path: Option[String],
    table: Option[TableIdentifier],
    horizonHours: Option[Double],
    dryRun: Boolean) extends RunnableCommand {

  override val output: Seq[Attribute] =
    Seq(AttributeReference("path", StringType, nullable = true)())

  override def run(sparkSession: SparkSession): Seq[Row] = {
    val pathToVacuum =
      if (path.nonEmpty) {
        new Path(path.get)
      } else if (table.nonEmpty) {
        DeltaTableIdentifier(sparkSession, table.get) match {
          case Some(id) if id.path.nonEmpty =>
            new Path(id.path.get)
          case _ =>
            new Path(sparkSession.sessionState.catalog.getTableMetadata(table.get).location)
        }
      } else {
        throw DeltaErrors.missingTableIdentifierException("VACUUM")
      }
    val baseDeltaPath = DeltaTableUtils.findDeltaTableRoot(sparkSession, pathToVacuum)
    if (baseDeltaPath.isDefined) {
      if (baseDeltaPath.get != pathToVacuum) {
        throw DeltaErrors.vacuumBasePathMissingException(baseDeltaPath.get)
      }
    }
    val deltaLog = DeltaLog.forTable(sparkSession, pathToVacuum)
    if (deltaLog.snapshot.version == -1) {
      throw DeltaErrors.notADeltaTableException(
        "VACUUM",
        DeltaTableIdentifier(path = Some(pathToVacuum.toString)))
    }
    VacuumCommand.gc(sparkSession, deltaLog, dryRun, horizonHours).collect()
  }
} 
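For context, a minimal sketch of how this command is normally reached from the public Delta Lake Scala API rather than raw SQL; the table path is illustrative.

import io.delta.tables.DeltaTable

val deltaTable = DeltaTable.forPath(spark, "/tmp/delta/events")
deltaTable.vacuum(168) // keep the last 7 days of history; returns an empty DataFrame on completion
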
Example 116
Source File: Lambda.scala    From mmlspark   with MIT License 5 votes vote down vote up
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.contracts.Wrappable
import org.apache.spark.SparkContext
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.param.{ParamMap, UDFParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

object Lambda extends ComplexParamsReadable[Lambda] {
  def apply(f: Dataset[_] => DataFrame): Lambda = {
    new Lambda().setTransform(f)
  }
}

class Lambda(val uid: String) extends Transformer with Wrappable with ComplexParamsWritable {
  def this() = this(Identifiable.randomUID("Lambda"))

  val transformFunc = new UDFParam(this, "transformFunc", "holder for dataframe function")

  def setTransform(f: Dataset[_] => DataFrame): this.type = {
    set(transformFunc, udf(f, StringType))
  }

  def getTransform: Dataset[_] => DataFrame = {
    $(transformFunc).f.asInstanceOf[Dataset[_] => DataFrame]
  }

  val transformSchemaFunc = new UDFParam(this, "transformSchemaFunc", "the output schema after the transformation")

  def setTransformSchema(f: StructType => StructType): this.type = {
    set(transformSchemaFunc, udf(f, StringType))
  }

  def getTransformSchema: StructType => StructType = {
    $(transformSchemaFunc).f.asInstanceOf[StructType => StructType]
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    getTransform(dataset)
  }

  def transformSchema(schema: StructType): StructType = {
    if (get(transformSchemaFunc).isEmpty) {
      val sc = SparkContext.getOrCreate()
      val df = SparkSession.builder().getOrCreate().createDataFrame(sc.emptyRDD[Row], schema)
      transform(df).schema
    } else {
      getTransformSchema(schema)
    }
  }

  def copy(extra: ParamMap): Lambda = defaultCopy(extra)

} 
Example 117
Source File: StratifiedRepartitionSuite.scala    From mmlspark   with MIT License 5 votes vote down vote up
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.spark.TaskContext
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class StratifiedRepartitionSuite extends TestBase with TransformerFuzzing[StratifiedRepartition] {

  import session.implicits._

  val values = "values"
  val colors = "colors"
  val const = "const"

  lazy val input = Seq(
    (0, "Blue", 2),
    (0, "Red", 2),
    (0, "Green", 2),
    (1, "Purple", 2),
    (1, "Orange", 2),
    (1, "Indigo", 2),
    (2, "Violet", 2),
    (2, "Black", 2),
    (2, "White", 2),
    (3, "Gray", 2),
    (3, "Yellow", 2),
    (3, "Cerulean", 2)
  ).toDF(values, colors, const)

  test("Assert doing a stratified repartition will ensure all keys exist across all partitions") {
    val inputSchema = new StructType()
      .add(values, IntegerType).add(colors, StringType).add(const, IntegerType)
    val inputEnc = RowEncoder(inputSchema)
    val valuesFieldIndex = inputSchema.fieldIndex(values)
    val numPartitions = 3
    val trainData = input.repartition(numPartitions).select(values, colors, const)
      .mapPartitions(iter => {
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        // Remove all instances of 0 class on partition 1
        if (partId == 1) {
          iter.flatMap(row => {
            if (row.getInt(valuesFieldIndex) <= 0)
              None
            else Some(row)
          })
        } else {
          // Add back at least 3 instances on other partitions
          val oneOfEachExample = List(Row(0, "Blue", 2), Row(1, "Purple", 2), Row(2, "Black", 2), Row(3, "Gray", 2))
          (iter.toList.union(oneOfEachExample).union(oneOfEachExample).union(oneOfEachExample)).toIterator
        }
      })(inputEnc).cache()
    // Some debug to understand what data is on which partition
    trainData.foreachPartition { rows =>
      rows.foreach { row =>
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        println(s"Row: $row partition id: $partId")
      }
    }
    val stratifiedInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Equal).transform(trainData)
    // Assert stratified data contains all keys across all partitions, with extra count
    // for it to be evaluated
    stratifiedInputData
      .mapPartitions(iter => {
        val actualLabels = iter.map(row => row.getInt(valuesFieldIndex))
          .toArray.distinct.sorted.toList
        val expectedLabels = (0 to 3).toList
        if (actualLabels != expectedLabels)
          throw new Exception(s"Missing labels, actual: $actualLabels, expected: $expectedLabels")
        iter
      })(inputEnc).count()
    val stratifiedMixedInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Mixed).transform(trainData)
    assert(stratifiedMixedInputData.count() >= trainData.count())
    val stratifiedOriginalInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Original).transform(trainData)
    assert(stratifiedOriginalInputData.count() == trainData.count())
  }

  def testObjects(): Seq[TestObject[StratifiedRepartition]] = List(new TestObject(
    new StratifiedRepartition().setLabelCol(values).setMode(SPConstants.Equal), input))

  def reader: MLReadable[_] = StratifiedRepartition
} 
Example 118
Source File: HTTPSuite.scala    From mmlspark   with MIT License 5 votes vote down vote up
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.io.split2

import java.io.File

import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.io.http.HTTPSchema.string_to_response
import org.apache.http.impl.client.HttpClientBuilder
import org.apache.spark.sql.execution.streaming.{HTTPSinkProvider, HTTPSourceProvider}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StringType

class HTTPSuite extends TestBase with HTTPTestUtils {

  test("stream from HTTP", TestBase.Extended) {
    val q1 = session.readStream.format(classOf[HTTPSourceProvider].getName)
      .option("host", host)
      .option("port", port.toString)
      .option("path", apiPath)
      .load()
      .withColumn("contentLength", col("request.entity.contentLength"))
      .withColumn("reply", string_to_response(col("contentLength").cast(StringType)))
      .writeStream
      .format(classOf[HTTPSinkProvider].getName)
      .option("name", "foo")
      .queryName("foo")
      .option("replyCol", "reply")
      .option("checkpointLocation", new File(tmpDir.toFile, "checkpoints").toString)
      .start()

    Thread.sleep(5000)
    val client = HttpClientBuilder.create().build()
    val p1 = sendJsonRequest(client, Map("foo" -> 1, "bar" -> "here"), url)
    val p2 = sendJsonRequest(client, Map("foo" -> 1, "bar" -> "heree"), url)
    val p3 = sendJsonRequest(client, Map("foo" -> 1, "bar" -> "hereee"), url)
    val p4 = sendJsonRequest(client, Map("foo" -> 1, "bar" -> "hereeee"), url)
    val posts = List(p1, p2, p3, p4)
    val correctResponses = List(27, 28, 29, 30)

    posts.zip(correctResponses).foreach { p =>
      assert(p._1 === p._2.toString)
    }
    q1.stop()
    client.close()
  }

} 
Example 119
Source File: SimpleHTTPTransformerSuite.scala    From mmlspark   with MIT License 5 votes vote down vote up
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.io.split1

import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import com.microsoft.ml.spark.io.http.{HandlingUtils, JSONOutputParser, SimpleHTTPTransformer}
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{StringType, StructType}

class SimpleHTTPTransformerSuite
  extends TransformerFuzzing[SimpleHTTPTransformer] with WithServer {

  import session.implicits._

  lazy val df: DataFrame = sc.parallelize((1 to 10).map(Tuple1(_))).toDF("data")

  def simpleTransformer: SimpleHTTPTransformer =
    new SimpleHTTPTransformer()
      .setInputCol("data")
      .setOutputParser(new JSONOutputParser()
        .setDataType(new StructType().add("blah", StringType)))
      .setUrl(url)
      .setOutputCol("results")

  test("HttpTransformerTest") {
    val results = simpleTransformer.transform(df).collect
    assert(results.length == 10)
    results.foreach(r =>
      assert(r.getStruct(2).getString(0) === "more blah"))
    assert(results(0).schema.fields.length == 3)
  }

  test("HttpTransformerTest with Flaky Connection") {
    lazy val df2: DataFrame = sc.parallelize((1 to 5).map(Tuple1(_))).toDF("data")
    val results = simpleTransformer
      .setUrl(url + "/flaky")
      .setTimeout(1)
      .transform(df2).collect
    assert(results.length == 5)
  }

  test("Basic Handling") {
    val results = simpleTransformer
      .setHandler(HandlingUtils.basic)
      .transform(df).collect
    assert(results.length == 10)
    results.foreach(r =>
      assert(r.getStruct(2).getString(0) === "more blah"))
    assert(results(0).schema.fields.length == 3)
  }

  test("Concurrent HttpTransformerTest") {
    val results =
      new SimpleHTTPTransformer()
        .setInputCol("data")
        .setOutputParser(new JSONOutputParser()
          .setDataType(new StructType().add("blah", StringType)))
        .setUrl(url)
        .setOutputCol("results")
        .setConcurrency(3)
        .transform(df)
        .collect
    assert(results.length == 10)
    assert(results.forall(_.getStruct(2).getString(0) == "more blah"))
    assert(results(0).schema.fields.length == 3)
  }

  override def testObjects(): Seq[TestObject[SimpleHTTPTransformer]] =
    Seq(new TestObject(simpleTransformer, df))

  override def reader: MLReadable[_] = SimpleHTTPTransformer

} 
Example 120
Source File: ParserSuite.scala    From mmlspark   with MIT License 5 votes vote down vote up
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.io.split1

import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import com.microsoft.ml.spark.io.http._
import org.apache.http.client.methods.HttpPost
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DataFrame, SparkSession}

trait ParserUtils extends WithServer {

  def sampleDf(spark: SparkSession): DataFrame = {
    val df = spark.createDataFrame((1 to 10).map(Tuple1(_)))
      .toDF("data")
    val df2 = new JSONInputParser().setInputCol("data")
      .setOutputCol("parsedInput").setUrl(url)
      .transform(df)
      .withColumn("unparsedOutput", udf({ x: Int =>
        HTTPResponseData(
          Array(),
          Some(EntityData(
            "{\"foo\": \"here\"}".getBytes, None, None, None, false, false, false)),
          StatusLineData(ProtocolVersionData("foo", 1, 1), 200, "bar"),
          "en")
      }).apply(col("data"))
      )

    new JSONOutputParser()
      .setDataType(new StructType().add("foo", StringType))
      .setInputCol("unparsedOutput")
      .setOutputCol("parsedOutput")
      .transform(df2)
  }

  def makeTestObject[T <: Transformer](t: T, session: SparkSession): Seq[TestObject[T]] = {
    Seq(new TestObject(t, sampleDf(session)))
  }

}

class JsonInputParserSuite extends TransformerFuzzing[JSONInputParser] with ParserUtils {
  override def testObjects(): Seq[TestObject[JSONInputParser]] = makeTestObject(
    new JSONInputParser().setInputCol("data").setOutputCol("out")
      .setUrl(url), session)

  override def reader: MLReadable[_] = JSONInputParser
}

class JsonOutputParserSuite extends TransformerFuzzing[JSONOutputParser] with ParserUtils {
  override def testObjects(): Seq[TestObject[JSONOutputParser]] = makeTestObject(
    new JSONOutputParser().setInputCol("unparsedOutput").setOutputCol("out")
      .setDataType(new StructType().add("foo", StringType)), session)

  override def reader: MLReadable[_] = JSONOutputParser
}

class StringOutputParserSuite extends TransformerFuzzing[StringOutputParser] with ParserUtils {
  override def testObjects(): Seq[TestObject[StringOutputParser]] = makeTestObject(
    new StringOutputParser().setInputCol("unparsedOutput").setOutputCol("out"), session)

  override def reader: MLReadable[_] = StringOutputParser
}

class CustomInputParserSuite extends TransformerFuzzing[CustomInputParser] with ParserUtils {
  override def testObjects(): Seq[TestObject[CustomInputParser]] = makeTestObject(
    new CustomInputParser().setInputCol("data").setOutputCol("out")
      .setUDF({ x: Int => new HttpPost(s"http://$x") }), session)

  override def reader: MLReadable[_] = CustomInputParser
}

class CustomOutputParserSuite extends TransformerFuzzing[CustomOutputParser] with ParserUtils {
  override def testObjects(): Seq[TestObject[CustomOutputParser]] = makeTestObject(
    new CustomOutputParser().setInputCol("unparsedOutput").setOutputCol("out")
      .setUDF({ x: HTTPResponseData => x.locale }), session)

  override def reader: MLReadable[_] = CustomOutputParser
} 
Example 121
Source File: Locus.scala    From hail   with MIT License 5 votes vote down vote up
package is.hail.variant

import is.hail.annotations.Annotation
import is.hail.check.Gen
import is.hail.expr.Parser
import is.hail.utils._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.json4s._

import scala.collection.JavaConverters._
import scala.language.implicitConversions

object Locus {
  val simpleContigs: Seq[String] = (1 to 22).map(_.toString) ++ Seq("X", "Y", "MT")

  def apply(contig: String, position: Int, rg: ReferenceGenome): Locus = {
    rg.checkLocus(contig, position)
    Locus(contig, position)
  }

  def annotation(contig: String, position: Int, rg: Option[ReferenceGenome]): Annotation = {
    rg match {
      case Some(ref) => Locus(contig, position, ref)
      case None => Annotation(contig, position)
    }
  }

  def sparkSchema: StructType =
    StructType(Array(
      StructField("contig", StringType, nullable = false),
      StructField("position", IntegerType, nullable = false)))

  def fromRow(r: Row): Locus = {
    Locus(r.getAs[String](0), r.getInt(1))
  }

  def gen(rg: ReferenceGenome): Gen[Locus] = for {
    (contig, length) <- Contig.gen(rg)
    pos <- Gen.choose(1, length)
  } yield Locus(contig, pos)

  def parse(str: String, rg: ReferenceGenome): Locus = {
    val elts = str.split(":")
    val size = elts.length
    if (size < 2)
      fatal(s"Invalid string for Locus. Expecting contig:pos -- found '$str'.")

    val contig = elts.take(size - 1).mkString(":")
    Locus(contig, elts(size - 1).toInt, rg)
  }

  def parseInterval(str: String, rg: ReferenceGenome, invalidMissing: Boolean = false): Interval =
    Parser.parseLocusInterval(str, rg, invalidMissing)

  def parseIntervals(arr: Array[String], rg: ReferenceGenome, invalidMissing: Boolean): Array[Interval] = arr.map(parseInterval(_, rg, invalidMissing))

  def parseIntervals(arr: java.util.List[String], rg: ReferenceGenome, invalidMissing: Boolean = false): Array[Interval] = parseIntervals(arr.asScala.toArray, rg, invalidMissing)

  def makeInterval(contig: String, start: Int, end: Int, includesStart: Boolean, includesEnd: Boolean,
    rgBase: ReferenceGenome, invalidMissing: Boolean = false): Interval = {
    val rg = rgBase.asInstanceOf[ReferenceGenome]
    rg.toLocusInterval(Interval(Locus(contig, start), Locus(contig, end), includesStart, includesEnd), invalidMissing)
  }
}

case class Locus(contig: String, position: Int) {
  def toRow: Row = Row(contig, position)

  def toJSON: JValue = JObject(
    ("contig", JString(contig)),
    ("position", JInt(position)))

  def copyChecked(rg: ReferenceGenome, contig: String = contig, position: Int = position): Locus = {
    rg.checkLocus(contig, position)
    Locus(contig, position)
  }

  def isAutosomalOrPseudoAutosomal(rg: ReferenceGenome): Boolean = isAutosomal(rg) || inXPar(rg) || inYPar(rg)

  def isAutosomal(rg: ReferenceGenome): Boolean = !(inX(rg) || inY(rg) || isMitochondrial(rg))

  def isMitochondrial(rg: ReferenceGenome): Boolean = rg.isMitochondrial(contig)

  def inXPar(rg: ReferenceGenome): Boolean = rg.inXPar(this)

  def inYPar(rg: ReferenceGenome): Boolean = rg.inYPar(this)

  def inXNonPar(rg: ReferenceGenome): Boolean = inX(rg) && !inXPar(rg)

  def inYNonPar(rg: ReferenceGenome): Boolean = inY(rg) && !inYPar(rg)

  private def inX(rg: ReferenceGenome): Boolean = rg.inX(contig)

  private def inY(rg: ReferenceGenome): Boolean = rg.inY(contig)

  override def toString: String = s"$contig:$position"
} 
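A minimal sketch of the plain case-class usage shown above, with no ReferenceGenome validation involved; the coordinates are made up.

val locus = Locus("1", 12345)
locus.toString // "1:12345"
locus.toRow    // Row("1", 12345), matching Locus.sparkSchema
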
Example 122
Source File: ConcatColumnBenchmark.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up
package org.opencypher.morpheus.jmh

import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{DataFrame, SparkSession, functions}
import org.apache.spark.storage.StorageLevel
import org.opencypher.morpheus.impl.MorpheusFunctions
import org.opencypher.morpheus.impl.expressions.EncodeLong._
import org.openjdk.jmh.annotations._

@State(Scope.Benchmark)
@BenchmarkMode(Array(Mode.AverageTime))
class ConcatColumnBenchmark {

  implicit var sparkSession: SparkSession = _

  var df: DataFrame = _

  @Setup
  def setUp(): Unit = {
    sparkSession = SparkSession.builder().master("local[*]").getOrCreate()
    val fromRow = 100000000L
    val numRows = 1000000
    val rangeDf = sparkSession.range(fromRow, fromRow + numRows).toDF("i")
    val indexCol = rangeDf.col("i")
    df = rangeDf
      .withColumn("s", indexCol.cast(StringType))
      .withColumn("b", indexCol.encodeLongAsMorpheusId)
      .partitionAndCache
  }

  @Benchmark
  def concatWs(): Int = {
    val result = df.withColumn("c", functions.concat_ws("|", df.col("i"), df.col("s"), df.col("b")))
    result.select("c").collect().length
  }

  @Benchmark
  def serialize(): Int = {
    val result = df.withColumn("c", MorpheusFunctions.serialize(df.col("i"), df.col("s"), df.col("b")))
    result.select("c").collect().length
  }

  implicit class DataFrameSetup(df: DataFrame) {

    def partitionAndCache: DataFrame = {
      val cached = df.repartition(10).persist(StorageLevel.MEMORY_ONLY)
      cached.count()
      cached
    }
  }

} 
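Outside the JMH harness, a minimal sketch of the concat_ws variant benchmarked above; MorpheusFunctions.serialize and encodeLongAsMorpheusId are project-specific and omitted here.

import org.apache.spark.sql.functions.concat_ws
import org.apache.spark.sql.types.StringType

val base = spark.range(0, 5).toDF("i")
val withString = base.withColumn("s", base.col("i").cast(StringType))
withString.withColumn("c", concat_ws("|", withString.col("i"), withString.col("s"))).show()
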
Example 123
Source File: LoadInteractionsInHive.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up
package org.opencypher.morpheus.util

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.opencypher.morpheus.api.MorpheusSession

object LoadInteractionsInHive {

  val databaseName = "customers"
  val baseTableName = s"$databaseName.csv_input"

  def load(show: Boolean = false)(implicit session: MorpheusSession): DataFrame = {

    val datafile = getClass.getResource("/customer-interactions/csv/customer-interactions.csv").toURI.getPath
    val structType = StructType(Seq(
      StructField("interactionId", LongType, nullable = false),
      StructField("date", StringType, nullable = false),
      StructField("customerIdx", LongType, nullable = false),
      StructField("empNo", LongType, nullable = false),
      StructField("empName", StringType, nullable = false),
      StructField("type", StringType, nullable = false),
      StructField("outcomeScore", StringType, nullable = false),
      StructField("accountHolderId", StringType, nullable = false),
      StructField("policyAccountNumber", StringType, nullable = false),
      StructField("customerId", StringType, nullable = false),
      StructField("customerName", StringType, nullable = false)
    ))

    val baseTable: DataFrame = session.sparkSession.read
      .format("csv")
      .option("header", "true")
      .schema(structType)
      .load(datafile)

    if (show) baseTable.show()

    session.sql(s"DROP DATABASE IF EXISTS $databaseName CASCADE")
    session.sql(s"CREATE DATABASE $databaseName")

    baseTable.write.saveAsTable(s"$baseTableName")

    // Create views for nodes
    createView(baseTableName, "interactions", true, "interactionId", "date", "type", "outcomeScore")
    createView(baseTableName, "customers", true, "customerIdx", "customerId", "customerName")
    createView(baseTableName, "account_holders", true, "accountHolderId")
    createView(baseTableName, "policies", true, "policyAccountNumber")
    createView(baseTableName, "customer_reps", true, "empNo", "empName")

    // Create views for relationships
    createView(baseTableName, "has_customer_reps", false, "interactionId", "empNo")
    createView(baseTableName, "has_customers", false, "interactionId", "customerIdx")
    createView(baseTableName, "has_policies", false, "interactionId", "policyAccountNumber")
    createView(baseTableName, "has_account_holders", false, "interactionId", "accountHolderId")

    baseTable
  }

  def createView(fromTable: String, viewName: String, distinct: Boolean, columns: String*)
    (implicit session: MorpheusSession): Unit = {
    val distinctString = if (distinct) "DISTINCT" else ""

    session.sql(
      s"""
         |CREATE VIEW $databaseName.${viewName}_SEED AS
         | SELECT $distinctString ${columns.mkString(", ")}
         | FROM $fromTable
         | WHERE date < '2017-01-01'
      """.stripMargin)

    session.sql(
      s"""
         |CREATE VIEW $databaseName.${viewName}_DELTA AS
         | SELECT $distinctString ${columns.mkString(", ")}
         | FROM $fromTable
         | WHERE date >= '2017-01-01'
      """.stripMargin)
  }

} 
Example 124
Source File: MorpheusRecordHeaderTest.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up
package org.opencypher.morpheus.impl.table

import org.apache.spark.sql.types.{ArrayType, StringType, StructField}
import org.opencypher.morpheus.impl.convert.SparkConversions._
import org.opencypher.okapi.api.types.{CTList, CTString}
import org.opencypher.okapi.ir.api.expr.Var
import org.opencypher.okapi.relational.impl.table.RecordHeader
import org.opencypher.okapi.testing.BaseTestSuite

class MorpheusRecordHeaderTest extends BaseTestSuite {

  it("computes a struct type from a given record header") {
    val header = RecordHeader.empty
      .withExpr(Var("a")(CTString))
      .withExpr(Var("b")(CTString.nullable))
      .withExpr(Var("c")(CTList(CTString.nullable)))

    header.toStructType.fields.toSet should equal(Set(
      StructField(header.column(Var("a")()), StringType, nullable = false),
      StructField(header.column(Var("b")()), StringType, nullable = true),
      StructField(header.column(Var("c")()), ArrayType(StringType, containsNull = true), nullable = false)
    ))
  }

} 
Example 125
Source File: NGram.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}


  @Since("1.5.0")
  def getN: Int = $(n)

  setDefault(n -> 2)

  override protected def createTransformFunc: Seq[String] => Seq[String] = {
    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType.sameType(ArrayType(StringType)),
      s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, false)
}

@Since("1.6.0")
object NGram extends DefaultParamsReadable[NGram] {

  @Since("1.6.0")
  override def load(path: String): NGram = super.load(path)
} 
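A minimal usage sketch for the transformer above; the column names and sample tokens are illustrative, while NGram itself is the standard Spark ML class.

import org.apache.spark.ml.feature.NGram

val wordsDf = spark.createDataFrame(Seq(
  (0, Seq("spark", "sql", "string", "type"))
)).toDF("id", "words")

val bigrams = new NGram().setN(2).setInputCol("words").setOutputCol("bigrams")
bigrams.transform(wordsDf).select("bigrams").show(false) // [spark sql, sql string, string type]
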
Example 126
Source File: MapDataSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }
} 
Example 127
Source File: ScalaUDFSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase,
      StringType,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

} 
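The suite above exercises the internal ScalaUDF expression directly; in user code the same behaviour normally goes through the public functions.udf API. A minimal sketch:

import org.apache.spark.sql.functions.{col, udf}

val appendX = udf((s: String) => s + "x") // result column is StringType
val df = spark.createDataFrame(Seq(Tuple1("a"))).toDF("s")
df.select(appendX(col("s")).as("sx")).show() // "ax"
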
Example 128
Source File: RewriteDistinctAggregatesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        CollectSet('b).toAggregateExpression().as('agg2))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 129
Source File: resources.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
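This command backs Spark SQL's LIST JARS statement; a minimal usage sketch, with an illustrative jar path.

spark.sql("ADD JAR /tmp/example-udfs.jar") // path is illustrative
spark.sql("LIST JARS").show(false)         // one StringType "Results" column per registered jar
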
Example 130
Source File: WholeStageCodegenSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined)
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }
} 
Example 131
Source File: GroupedIteratorSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 132
Source File: DDLSourceLoadSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// please note that the META-INF/services had to be modified for the test directory for this to work
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {
    intercept[RuntimeException] {
      spark.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {
    spark.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("specify full classname with duplicate formats") {
    spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("should fail to load ORC without Hive Support") {
    val e = intercept[AnalysisException] {
      spark.read.format("orc").load()
    }
    assert(e.message.contains("The ORC data source must be used with Hive support enabled"))
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider  with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
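For the short-name lookups in the suite above to resolve, the fake sources have to be registered with Java's ServiceLoader, which is what the META-INF/services note in the suite refers to. A sketch of what such a provider file would contain (the resource path is the standard ServiceLoader location; the exact file in the project may differ):

# src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
org.apache.spark.sql.sources.FakeSourceOne
org.apache.spark.sql.sources.FakeSourceTwo
org.apache.spark.sql.sources.FakeSourceThree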
Example 133
Source File: Tokenizer.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}


  def getPattern: String = $(pattern)

  setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+")

  override protected def createTransformFunc: String => Seq[String] = { str =>
    val re = $(pattern).r
    val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
    val minLength = $(minTokenLength)
    tokens.filter(_.length >= minLength)
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be string type but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, false)

  override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
} 
Example 134
Source File: GenerateOrdering.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{BinaryType, StringType, NumericType}


object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging {
  import scala.reflect.runtime.{universe => ru}
  import scala.reflect.runtime.universe._

  protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
    in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])

  protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
    in.map(BindReferences.bindReference(_, inputSchema))

  protected def create(ordering: Seq[SortOrder]): Ordering[Row] = {
    val a = newTermName("a")
    val b = newTermName("b")
    val comparisons = ordering.zipWithIndex.map { case (order, i) =>
      val evalA = expressionEvaluator(order.child)
      val evalB = expressionEvaluator(order.child)

      val compare = order.child.dataType match {
        case BinaryType =>
          q"""
          val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm}
          val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm}
          var i = 0
          while (i < x.length && i < y.length) {
            val res = x(i).compareTo(y(i))
            if (res != 0) return res
            i = i+1
          }
          return x.length - y.length
          """
        case _: NumericType =>
          q"""
          val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
          if(comp != 0) {
            return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
          }
          """
        case StringType =>
          if (order.direction == Ascending) {
            q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})"""
          } else {
            q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})"""
          }
      }

      q"""
        i = $a
        ..${evalA.code}
        i = $b
        ..${evalB.code}
        if (${evalA.nullTerm} && ${evalB.nullTerm}) {
          // Nothing
        } else if (${evalA.nullTerm}) {
          return ${if (order.direction == Ascending) q"-1" else q"1"}
        } else if (${evalB.nullTerm}) {
          return ${if (order.direction == Ascending) q"1" else q"-1"}
        } else {
          $compare
        }
      """
    }

    val q"class $orderingName extends $orderingType { ..$body }" = reify {
      class SpecificOrdering extends Ordering[Row] {
        val o = ordering
      }
    }.tree.children.head

    val code = q"""
      class $orderingName extends $orderingType {
        ..$body
        def compare(a: $rowType, b: $rowType): Int = {
          var i: $rowType = null // Holds current row being evaluated.
          ..$comparisons
          return 0
        }
      }
      new $orderingName()
      """
    logDebug(s"Generated Ordering: $code")
    toolBox.eval(code).asInstanceOf[Ordering[Row]]
  }
} 
Example 135
Source File: SparkSQLParser.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import scala.util.parsing.combinator.RegexParsers

import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.types.StringType



private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {

  // A parser for the key-value part of the "SET [key = [value ]]" syntax
  private object SetCommandParser extends RegexParsers {
    private val key: Parser[String] = "(?m)[^=]+".r

    private val value: Parser[String] = "(?m).*$".r

    private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)())

    private val pair: Parser[LogicalPlan] =
      (key ~ ("=".r ~> value).?).? ^^ {
        case None => SetCommand(None, output)
        case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)), output)
      }

    def apply(input: String): LogicalPlan = parseAll(pair, input) match {
      case Success(plan, _) => plan
      case x => sys.error(x.toString)
    }
  }

  protected val AS = Keyword("AS")
  protected val CACHE = Keyword("CACHE")
  protected val CLEAR = Keyword("CLEAR")
  protected val IN = Keyword("IN")
  protected val LAZY = Keyword("LAZY")
  protected val SET = Keyword("SET")
  protected val SHOW = Keyword("SHOW")
  protected val TABLE = Keyword("TABLE")
  protected val TABLES = Keyword("TABLES")
  protected val UNCACHE = Keyword("UNCACHE")

  override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others

  private lazy val cache: Parser[LogicalPlan] =
    CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
      case isLazy ~ tableName ~ plan =>
        CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
    }

  private lazy val uncache: Parser[LogicalPlan] =
    ( UNCACHE ~ TABLE ~> ident ^^ {
        case tableName => UncacheTableCommand(tableName)
      }
    | CLEAR ~ CACHE ^^^ ClearCacheCommand
    )

  private lazy val set: Parser[LogicalPlan] =
    SET ~> restInput ^^ {
      case input => SetCommandParser(input)
    }

  private lazy val show: Parser[LogicalPlan] =
    SHOW ~> TABLES ~ (IN ~> ident).? ^^ {
      case _ ~ dbName => ShowTablesCommand(dbName)
    }

  private lazy val others: Parser[LogicalPlan] =
    wholeInput ^^ {
      case input => fallback(input)
    }

} 
Example 136
Source File: ListTablesSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}

class ListTablesSuite extends QueryTest with BeforeAndAfter {

  import org.apache.spark.sql.test.TestSQLContext.implicits._

  val df =
    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")

  before {
    df.registerTempTable("ListTablesSuiteTable")
  }

  after {
    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
  }

  test("get all tables") {
    checkAnswer(
      tables().filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    checkAnswer(
      sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
    assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("getting all Tables with a database name has no impact on returned table names") {
    checkAnswer(
      tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    checkAnswer(
      sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
    assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("query the returned DataFrame of tables") {
    val expectedSchema = StructType(
      StructField("tableName", StringType, false) ::
      StructField("isTemporary", BooleanType, false) :: Nil)

    Seq(tables(), sql("SHOW TABLes")).foreach {
      case tableDF =>
        assert(expectedSchema === tableDF.schema)

        tableDF.registerTempTable("tables")
        checkAnswer(
          sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
          Row(true, "ListTablesSuiteTable")
        )
        checkAnswer(
          tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
          Row("tables", true))
        dropTempTable("tables")
    }
  }
} 
Example 137
Source File: DiscreteDistributionBuilder.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperables.dataframe.report.distribution.discrete

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{BooleanType, StringType, StructField}

import ai.deepsense.deeplang.doperables.dataframe.report.DataFrameReportGenerator
import ai.deepsense.deeplang.doperables.dataframe.report.distribution.{DistributionBuilder, NoDistributionReasons}
import ai.deepsense.deeplang.doperables.report.ReportUtils
import ai.deepsense.deeplang.utils.aggregators.Aggregator
import ai.deepsense.deeplang.utils.aggregators.AggregatorBatch.BatchedResult
import ai.deepsense.reportlib.model.{DiscreteDistribution, Distribution, NoDistribution}

case class DiscreteDistributionBuilder(
    categories: Aggregator[Option[scala.collection.mutable.Map[String, Long]], Row],
    missing: Aggregator[Long, Row],
    field: StructField)
  extends DistributionBuilder {

  def allAggregators: Seq[Aggregator[_, Row]] = Seq(categories, missing)

  override def build(results: BatchedResult): Distribution = {
    val categoriesMap = results.forAggregator(categories)
    val nullsCount = results.forAggregator(missing)

    categoriesMap match {
      case Some(occurrencesMap) => {
        val labels = field.dataType match {
          case StringType => occurrencesMap.keys.toSeq.sorted
          // We always want two labels, even when all elements are true or false
          case BooleanType => Seq(false.toString, true.toString)
        }
        val counts = labels.map(occurrencesMap.getOrElse(_, 0L))
        DiscreteDistribution(
          field.name,
          s"Discrete distribution for ${field.name} column",
          nullsCount,
          labels.map(ReportUtils.shortenLongStrings(_,
            DataFrameReportGenerator.StringPreviewMaxLength)),
          counts)
      }
      case None => NoDistribution(
        field.name,
        NoDistributionReasons.TooManyDistinctCategoricalValues
      )
    }
  }
} 
Example 138
Source File: CountVectorizerExample.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperations.examples

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

import ai.deepsense.deeplang.doperables.dataframe.{DataFrame, DataFrameBuilder}
import ai.deepsense.deeplang.doperations.spark.wrappers.estimators.CountVectorizer

class CountVectorizerExample extends AbstractOperationExample[CountVectorizer]{
  override def dOperation: CountVectorizer = {
    val op = new CountVectorizer()
    op.estimator
      .setInputColumn("lines")
      .setNoInPlace("lines_out")
      .setMinTF(3)
    op.set(op.estimator.extractParamMap())
  }

  override def inputDataFrames: Seq[DataFrame] = {
    val rows = Seq(
      Row("a a a b b c c c d ".split(" ").toSeq),
      Row("c c c c c c".split(" ").toSeq),
      Row("a".split(" ").toSeq),
      Row("e e e e e".split(" ").toSeq))
    val rdd = sparkContext.parallelize(rows)
    val schema = StructType(Seq(StructField("lines", ArrayType(StringType, containsNull = true))))
    Seq(DataFrameBuilder(sparkSQLSession).buildDataFrame(schema, rdd))
  }
} 
Example 139
Source File: UnionIntegSpec.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperations

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

import ai.deepsense.deeplang.doperables.dataframe.DataFrame
import ai.deepsense.deeplang.doperations.exceptions.SchemaMismatchException
import ai.deepsense.deeplang.inference.{InferContext, InferenceWarnings}
import ai.deepsense.deeplang.{DKnowledge, DeeplangIntegTestSupport}

class UnionIntegSpec extends DeeplangIntegTestSupport {

  import DeeplangIntegTestSupport._
  val schema1 = StructType(List(
    StructField("column1", DoubleType),
    StructField("column2", DoubleType)))

  val rows1_1 = Seq(
    Row(1.0, 2.0),
    Row(2.0, 3.0)
  )

  "Union" should {
    "return a union of two DataFrames" in {
      val rows1_2 = Seq(
        Row(2.0, 4.0),
        Row(4.0, 6.0)
      )

      val df1 = createDataFrame(rows1_1, schema1)
      val df2 = createDataFrame(rows1_2, schema1)

      val merged = Union()
        .executeUntyped(Vector(df1, df2))(executionContext)
        .head.asInstanceOf[DataFrame]

      assertDataFramesEqual(
        merged, createDataFrame(rows1_1 ++ rows1_2, schema1))
    }

    "throw for mismatching types in DataFrames" in {
      val schema2 = StructType(List(
        StructField("column1", StringType),
        StructField("column2", DoubleType)))

      val rows2_1 = Seq(
        Row("a", 1.0),
        Row("b", 1.0)
      )

      val df1 = createDataFrame(rows1_1, schema1)
      val df2 = createDataFrame(rows2_1, schema2)

      a [SchemaMismatchException] should be thrownBy {
        Union().executeUntyped(Vector(df1, df2))(executionContext)
      }
    }

    "throw for mismatching column names in DataFrames" in {
      val schema2 = StructType(List(
        StructField("column1", DoubleType),
        StructField("different_column_name", DoubleType)))

      val rows2_1 = Seq(
        Row(1.1, 1.0),
        Row(1.1, 1.0)
      )

      val df1 = createDataFrame(rows1_1, schema1)
      val df2 = createDataFrame(rows2_1, schema2)

      a [SchemaMismatchException] should be thrownBy {
        Union().executeUntyped(Vector(df1, df2))(executionContext)
      }
    }
  }

  it should {
    "propagate schema when both schemas match" in {
      val structType = StructType(Seq(
        StructField("x", DoubleType),
        StructField("y", DoubleType)))
      val knowledgeDF1 = DKnowledge(DataFrame.forInference(structType))
      val knowledgeDF2 = DKnowledge(DataFrame.forInference(structType))
      Union().inferKnowledgeUntyped(Vector(knowledgeDF1, knowledgeDF2))(mock[InferContext]) shouldBe
        (Vector(knowledgeDF1), InferenceWarnings())
    }
    "generate error when schemas don't match" in {
      val structType1 = StructType(Seq(
        StructField("x", DoubleType)))
      val structType2 = StructType(Seq(
        StructField("y", DoubleType)))
      val knowledgeDF1 = DKnowledge(DataFrame.forInference(structType1))
      val knowledgeDF2 = DKnowledge(DataFrame.forInference(structType2))
      an [SchemaMismatchException] shouldBe thrownBy(
        Union().inferKnowledgeUntyped(Vector(knowledgeDF1, knowledgeDF2))(mock[InferContext]))
    }
  }
} 
Example 140
Source File: AbstractEvaluatorSmokeTest.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperables

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

import ai.deepsense.deeplang.doperables.dataframe.DataFrame
import ai.deepsense.deeplang.params.ParamPair
import ai.deepsense.deeplang.{DKnowledge, DeeplangIntegTestSupport}
import ai.deepsense.sparkutils.Linalg.Vectors

abstract class AbstractEvaluatorSmokeTest extends DeeplangIntegTestSupport {

  def className: String

  val evaluator: Evaluator

  val evaluatorParams: Seq[ParamPair[_]]

  val inputDataFrameSchema = StructType(Seq(
    StructField("s", StringType),
    StructField("prediction", DoubleType),
    StructField("rawPrediction", new ai.deepsense.sparkutils.Linalg.VectorUDT),
    StructField("label", DoubleType)
  ))

  val inputDataFrame: DataFrame = {
    val rowSeq = Seq(
      Row("aAa bBb cCc dDd eEe f", 1.0, Vectors.dense(2.1, 2.2, 2.3), 3.0),
      Row("das99213 99721 8i!#@!", 4.0, Vectors.dense(5.1, 5.2, 5.3), 6.0)
    )
    createDataFrame(rowSeq, inputDataFrameSchema)
  }

  def setUpStubs(): Unit = ()

  className should {
    "successfully run _evaluate()" in {
      setUpStubs()
      evaluator.set(evaluatorParams: _*)._evaluate(executionContext, inputDataFrame)
    }
    "successfully run _infer()" in {
      evaluator.set(evaluatorParams: _*)._infer(DKnowledge(inputDataFrame))
    }
    "successfully run report" in {
      evaluator.set(evaluatorParams: _*).report()
    }
  }
} 
Example 141
Source File: StringTokenizerSmokeTest.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperables.spark.wrappers.transformers

import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

import ai.deepsense.deeplang.doperables.multicolumn.MultiColumnParams.SingleOrMultiColumnChoices.SingleColumnChoice
import ai.deepsense.deeplang.doperables.multicolumn.SingleColumnParams.SingleTransformInPlaceChoices.NoInPlaceChoice
import ai.deepsense.deeplang.params.selections.NameSingleColumnSelection

class StringTokenizerSmokeTest
  extends AbstractTransformerWrapperSmokeTest[StringTokenizer]
  with MultiColumnTransformerWrapperTestSupport {

  override def transformerWithParams: StringTokenizer = {
     val inPlace = NoInPlaceChoice()
      .setOutputColumn("tokenized")

    val single = SingleColumnChoice()
      .setInputColumn(NameSingleColumnSelection("s"))
      .setInPlace(inPlace)

    val transformer = new StringTokenizer()
    transformer.set(Seq(
      transformer.singleOrMultiChoiceParam -> single
    ): _*)
  }

  override def testValues: Seq[(Any, Any)] = {
    val strings = Seq(
      "this is a test",
      "this values should be separated",
      "Bla bla bla!"
    )

    val tokenized = strings.map { _.toLowerCase.split("\\s") }
    strings.zip(tokenized)
  }

  override def inputType: DataType = StringType

  override def outputType: DataType = new ArrayType(StringType, true)
} 
Example 142
Source File: RegexTokenizerSmokeTest.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperables.spark.wrappers.transformers

import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

import ai.deepsense.deeplang.doperables.multicolumn.MultiColumnParams.SingleOrMultiColumnChoices.SingleColumnChoice
import ai.deepsense.deeplang.doperables.multicolumn.SingleColumnParams.SingleTransformInPlaceChoices.NoInPlaceChoice
import ai.deepsense.deeplang.params.selections.NameSingleColumnSelection

class RegexTokenizerSmokeTest
  extends AbstractTransformerWrapperSmokeTest[RegexTokenizer]
  with MultiColumnTransformerWrapperTestSupport {

  override def transformerWithParams: RegexTokenizer = {
    val inPlace = NoInPlaceChoice()
      .setOutputColumn("tokenized")

    val single = SingleColumnChoice()
      .setInputColumn(NameSingleColumnSelection("s"))
      .setInPlace(inPlace)

    val transformer = new RegexTokenizer()
    transformer.set(Seq(
      transformer.singleOrMultiChoiceParam -> single,
      transformer.gaps -> false,
      transformer.minTokenLength -> 1,
      transformer.pattern -> "\\d+"
    ): _*)
  }

  override def testValues: Seq[(Any, Any)] = {
    val strings = Seq(
      "100 200 300",
      "400 500 600",
      "700 800 900"
    )

    val tokenized = strings.map { _.toLowerCase.split(" ") }
    strings.zip(tokenized)
  }

  override def inputType: DataType = StringType

  override def outputType: DataType = new ArrayType(StringType, true)
} 
Example 143
Source File: NGramTransformerSmokeTest.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperables.spark.wrappers.transformers

import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

import ai.deepsense.deeplang.doperables.multicolumn.MultiColumnParams.SingleOrMultiColumnChoices.SingleColumnChoice
import ai.deepsense.deeplang.doperables.multicolumn.SingleColumnParams.SingleTransformInPlaceChoices.NoInPlaceChoice
import ai.deepsense.deeplang.params.selections.NameSingleColumnSelection

class NGramTransformerSmokeTest
  extends AbstractTransformerWrapperSmokeTest[NGramTransformer]
  with MultiColumnTransformerWrapperTestSupport {

  override def transformerWithParams: NGramTransformer = {
    val inPlace = NoInPlaceChoice()
      .setOutputColumn("ngrams")

    val single = SingleColumnChoice()
      .setInputColumn(NameSingleColumnSelection("as"))
      .setInPlace(inPlace)

    val transformer = new NGramTransformer()
    transformer.set(Seq(
      transformer.singleOrMultiChoiceParam -> single,
      transformer.n -> 2
    ): _*)
  }

  override def testValues: Seq[(Any, Any)] = {
    val strings = Seq(
      Array("a", "b", "c"),
      Array("d", "e", "f")
    )

    val ngrams = Seq(
      Array("a b", "b c"),
      Array("d e", "e f")
    )
    strings.zip(ngrams)
  }

  override def inputType: DataType = new ArrayType(StringType, true)

  override def outputType: DataType = new ArrayType(StringType, false)
} 
Example 144
Source File: StopWordsRemoverSmokeTest.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperables.spark.wrappers.transformers

import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

import ai.deepsense.deeplang.doperables.multicolumn.MultiColumnParams.SingleOrMultiColumnChoices.SingleColumnChoice
import ai.deepsense.deeplang.doperables.multicolumn.SingleColumnParams.SingleTransformInPlaceChoices.NoInPlaceChoice
import ai.deepsense.deeplang.params.selections.NameSingleColumnSelection

class StopWordsRemoverSmokeTest
    extends AbstractTransformerWrapperSmokeTest[StopWordsRemover]
    with MultiColumnTransformerWrapperTestSupport  {

  override def transformerWithParams: StopWordsRemover = {
    val inPlace = NoInPlaceChoice()
      .setOutputColumn("stopWordsRemoverOutput")
    val single = SingleColumnChoice()
      .setInputColumn(NameSingleColumnSelection("as"))
      .setInPlace(inPlace)

    val stopWordsRemover = new StopWordsRemover()
    stopWordsRemover.set(
      stopWordsRemover.singleOrMultiChoiceParam -> single,
      stopWordsRemover.caseSensitive -> false)
  }

  override def testValues: Seq[(Any, Any)] = {
    val inputNumbers = Seq(Array("a", "seahorse", "The", "Horseshoe", "Crab"))
    val outputNumbers = Seq(Array("seahorse", "Horseshoe", "Crab"))
    inputNumbers.zip(outputNumbers)
  }

  override def inputType: DataType = ArrayType(StringType)

  override def outputType: DataType = ArrayType(StringType)
} 
Example 145
Source File: HBaseLoader.scala    From DataQuality   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package it.agilelab.bigdata.DataQuality.utils.io.db.readers

import it.agilelab.bigdata.DataQuality.sources.HBaseSrcConfig
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import it.nerdammer.spark.hbase._
import it.nerdammer.spark.hbase.conversion.FieldReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{StringType, StructField, StructType}


object HBaseLoader {

  private def seqReader[T]()(implicit m1: FieldReader[Option[T]]): FieldReader[Seq[Option[T]]] = new FieldReader[Seq[Option[T]]] {
    def map(data: HBaseData): Seq[Option[T]] = data.map(optArr => m1.map(Iterable(optArr))).toSeq
  }

  
  def loadToDF(conf: HBaseSrcConfig)(implicit sqlContext: SQLContext): DataFrame = {

    implicit val stringSeqReader: FieldReader[Seq[Option[String]]] = seqReader[String]()

    val header: Seq[String] = Seq("key") ++ conf.hbaseColumns
    val rdd: RDD[Row] =
      sqlContext.sparkContext
        .hbaseTable[Seq[Option[String]]](conf.table)
        .select(conf.hbaseColumns: _*)
        .map {
          case (s: Seq[Option[String]]) => Row.fromSeq(s)
        }
    val struct = StructType(header.map(StructField(_, StringType)))
    sqlContext.createDataFrame(rdd, struct)
  }

} 
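The core pattern above, building an all-StringType schema from a header and pairing it with an RDD[Row], also works outside the HBase reader. A minimal standalone sketch, assuming an SQLContext is available just as it is in loadToDF:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Every column, including the key, is modelled as a nullable StringType field.
def rowsToStringDF(header: Seq[String], rows: RDD[Row], sqlContext: SQLContext): DataFrame = {
  val struct = StructType(header.map(StructField(_, StringType)))
  sqlContext.createDataFrame(rows, struct)
}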
Example 146
Source File: ColumnDescriptorSuite.scala    From kyuubi   with Apache License 2.0 5 votes vote down vote up
package yaooqinn.kyuubi.schema

import org.apache.hive.service.cli.thrift.{TCLIServiceConstants, TTypeId}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}

class ColumnDescriptorSuite extends SparkFunSuite {

  test("Column Descriptor basic test") {
    val col1 = "a"
    val col2 = "b"
    val comments = "no comments"
    val schema = new StructType()
      .add(col1, StringType, nullable = true, comments)
      .add(col2, DecimalType(10, 9), nullable = true, "")

    val tColumnDescs =
      (0 until schema.length).map(i => ColumnDescriptor(schema(i), i)).map(_.toTColumnDesc)
    assert(tColumnDescs.head.getColumnName === col1)
    assert(tColumnDescs.head.getComment === comments)
    assert(tColumnDescs.head.getPosition === 0)
    assert(tColumnDescs.head.getTypeDesc === TypeDescriptor(StringType).toTTypeDesc)
    assert(tColumnDescs.head.getTypeDesc.getTypesSize === 1)
    assert(tColumnDescs.head.getTypeDesc.getTypes.get(0)
      .getPrimitiveEntry.getTypeQualifiers === null)
    assert(tColumnDescs.head.getTypeDesc.getTypes.get(0)
      .getPrimitiveEntry.getType === TTypeId.STRING_TYPE)

    assert(tColumnDescs(1).getColumnName === col2)
    assert(tColumnDescs(1).getComment === "")
    assert(tColumnDescs(1).getPosition === 1)
    assert(tColumnDescs(1).getTypeDesc.getTypesSize === 1)
    assert(tColumnDescs(1)
      .getTypeDesc
      .getTypes.get(0)
      .getPrimitiveEntry
      .getTypeQualifiers
      .getQualifiers
      .get(TCLIServiceConstants.PRECISION).getI32Value === 10)
    assert(tColumnDescs(1)
      .getTypeDesc
      .getTypes.get(0)
      .getPrimitiveEntry
      .getTypeQualifiers
      .getQualifiers
      .get(TCLIServiceConstants.SCALE).getI32Value === 9)
    assert(tColumnDescs(1).getTypeDesc.getTypes.get(0)
      .getPrimitiveEntry.getType === TTypeId.DECIMAL_TYPE)

  }

  test("field is null") {
    val tColumnDesc = ColumnDescriptor(null, 0).toTColumnDesc
    assert(tColumnDesc.isSetPosition)
    assert(!tColumnDesc.isSetColumnName)
    assert(!tColumnDesc.isSetTypeDesc)
    assert(!tColumnDesc.isSetComment)
  }
} 
Example 147
Source File: FilterSpec.scala    From jgit-spark-connector   with Apache License 2.0 5 votes vote down vote up
package tech.sourced.engine.util

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StringType
import org.scalatest.{FlatSpec, Matchers}

class FilterSpec extends FlatSpec with Matchers {
  "CompiledFilters" should "filter properly depending of his type" in {
    val eq = EqualFilter(Attr("test", ""), "a")

    eq.eval("a") should be(true)
    eq.eval("b") should be(false)

    val notEq = NotFilter(EqualFilter(Attr("test", ""), "a"))

    notEq.eval("a") should be(false)
    notEq.eval("b") should be(true)

    val in = InFilter(Attr("test", ""), Array("a", "b", "c"))

    in.eval("a") should be(true)
    in.eval("b") should be(true)
    in.eval("c") should be(true)
    in.eval("d") should be(false)

    val gt = GreaterThanFilter(Attr("test", ""), 5)

    gt.eval(4) should be(false)
    gt.eval(5) should be(false)
    gt.eval(6) should be(true)

    val gte = GreaterThanOrEqualFilter(Attr("test", ""), 5)

    gte.eval(4) should be(false)
    gte.eval(5) should be(true)
    gte.eval(6) should be(true)

    val lt = LessThanFilter(Attr("test", ""), 5)

    lt.eval(4) should be(true)
    lt.eval(5) should be(false)
    lt.eval(6) should be(false)

    val lte = LessThanOrEqualFilter(Attr("test", ""), 5)

    lte.eval(4) should be(true)
    lte.eval(5) should be(true)
    lte.eval(6) should be(false)
  }

  "ColumnFilter" should "process correctly columns" in {
    // test = 'val' AND test IS NOT NULL AND test2 = 'val2' AND test3 IN ('a', 'b')
    val f = Filter.compile(And(
      And(
        And(
          EqualTo(AttributeReference("test", StringType)(), Literal("val")),
          IsNotNull(AttributeReference("test", StringType)())
        ),
        EqualTo(AttributeReference("test2", StringType)(), Literal("val2"))
      ),
      In(AttributeReference("test3", StringType)(), Seq(Literal("a"), Literal("b")))
    ))

    f.length should be(4)
    val filters = Filters(f)
    filters.matches(Seq("test"), "val") should be(true)
    filters.matches(Seq("test2"), "val") should be(false)
    filters.matches(Seq("test3"), "b") should be(true)
  }

  "ColumnFilter" should "handle correctly unsupported filters" in {
    val f = Filter.compile(StartsWith(AttributeReference("test", StringType)(), Literal("a")))

    f.length should be(0)
  }
} 
Example 148
Source File: MetadataIteratorSpec.scala    From jgit-spark-connector   with Apache License 2.0 5 votes vote down vote up
package tech.sourced.engine.iterator

import java.nio.file.Paths
import java.util.{Properties, UUID}

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{Metadata, StringType, StructType}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}
import tech.sourced.engine.{BaseSparkSpec, Schema}

class JDBCQueryIteratorSpec
  extends FlatSpec with Matchers with BeforeAndAfterAll with BaseSparkSpec {
  private val tmpPath = Paths.get(
    System.getProperty("java.io.tmpdir"),
    UUID.randomUUID.toString
  )

  private val dbPath = tmpPath.resolve("test.db")

  override def beforeAll(): Unit = {
    super.beforeAll()
    tmpPath.toFile.mkdir()
    val rdd = ss.sparkContext.parallelize(Seq(
      Row("id1"),
      Row("id2"),
      Row("id3")
    ))

    val properties = new Properties()
    properties.put("driver", "org.sqlite.JDBC")
    val df = ss.createDataFrame(rdd, StructType(Seq(Schema.repositories.head)))
    df.write.jdbc(s"jdbc:sqlite:${dbPath.toString}", "repositories", properties)
  }

  override def afterAll(): Unit = {
    super.afterAll()
    FileUtils.deleteQuietly(tmpPath.toFile)
  }

  "JDBCQueryIterator" should "return all rows for the query" in {
    val iter = new JDBCQueryIterator(
      Seq(attr("id")),
      dbPath.toString,
      "SELECT id FROM repositories ORDER BY id"
    )

    // calling hasNext more than one time does not cause rows to be lost
    iter.hasNext
    iter.hasNext
    val rows = (for (row <- iter) yield row).toArray
    rows.length should be(3)
    rows(0).length should be(1)
    rows(0)(0).toString should be("id1")
    rows(1)(0).toString should be("id2")
    rows(2)(0).toString should be("id3")
  }

  private def attr(name: String): Attribute = AttributeReference(
    name, StringType, nullable = false, Metadata.empty
  )()
} 
Example 149
Source File: MultilayerPerceptronClassifierExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// $example off$
import org.apache.spark.sql.Row
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}


    result.show(5)
    val predictionAndLabels = result.select("prediction", "label")
    // Multiclass classification evaluator
    val evaluator = new MulticlassClassificationEvaluator()
      .setMetricName("precision")
    // Accuracy: 0.9636363636363636
    println("Accuracy: " + evaluator.evaluate(predictionAndLabels))
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 150
Source File: GradientBoostedTreeRegressorExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
// $example off$
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

    predictions.select("prediction", "label", "features").show(5)

    // Select (prediction, true label) and compute test error.
    val evaluator = new RegressionEvaluator()
      .setLabelCol("label")//标签列名
      //预测结果列名
      .setPredictionCol("prediction")
       //rmse均方根误差说明样本的离散程度
      .setMetricName("rmse")
    val rmse = evaluator.evaluate(predictions)
     //rmse均方根误差说明样本的离散程度
    println("Root Mean Squared Error (RMSE) on test data = " + rmse)

    val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel]
    println("Learned regression GBT model:\n" + gbtModel.toDebugString)
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 151
Source File: LinearRegressionWithElasticNetExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.regression.LinearRegression
// $example off$
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

    trainingSummary.residuals.show()
    // RMSE (root mean squared error) indicates how dispersed the samples are
    //RMSE: 10.189126225286143
    println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
    // R-squared, also called the coefficient of determination, measures how well the model fits the data
    //r2: 0.02285205756871944
    println(s"r2: ${trainingSummary.r2}")
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 152
Source File: NaiveBayesExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}
// $example off$


    predictions.show(5)

    // Select (prediction, true label) and compute test error
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")//标签列名
      .setPredictionCol("prediction")//预测结果列名
      .setMetricName("precision")//准确率
    //Accuracy: 1.0
    val accuracy = evaluator.evaluate(predictions)
    println("Accuracy: " + accuracy)
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 153
Source File: VectorSlicerExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import java.util.Arrays

import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.feature.VectorSlicer

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
// $example off$
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

    output.show()
    println(output.select("userFeatures", "features").first())
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 154
Source File: CountVectorizerExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}
// $example off$
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

    // transform() converts one DataFrame into another DataFrame
    cvm.transform(df).select("features").show()
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 155
Source File: Word2VecExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.Word2Vec

// $example off$

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

    result.show()
    result.select("result").take(3).foreach(println)
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 156
Source File: TokenizerExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer}
// $example off$
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

    regexTokenized.show()
    regexTokenized.select("words", "label").take(3).foreach(println)
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
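The excerpt above begins after the tokenizers have already been applied. Below is a sketch of the setup it presupposes, with illustrative data and an assumed sqlContext in scope; the project's actual input may differ, but it needs a StringType "sentence" column plus the "label" column selected at the end.

    val sentenceDataFrame = sqlContext.createDataFrame(Seq(
      (0, "Hi I heard about Spark"),
      (1, "I wish Java could use case classes"),
      (2, "Logistic,regression,models,are,neat")
    )).toDF("label", "sentence")

    val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
    val regexTokenizer = new RegexTokenizer()
      .setInputCol("sentence")
      .setOutputCol("words")
      .setPattern("\\W")  // split on non-word characters

    val regexTokenized = regexTokenizer.transform(sentenceDataFrame)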
Example 157
Source File: StringIndexerExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.StringIndexer
// $example off$
import org.apache.spark.sql.Row
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
    // fit() converts a DataFrame into a Transformer (a fitted model)
    // transform() converts one DataFrame into another DataFrame
    val indexed = indexer.fit(df).transform(df)
    indexed.show()
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
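As above, the excerpt starts after the input DataFrame is created. A sketch of the omitted setup, with illustrative data and an assumed sqlContext in scope: StringIndexer maps the StringType "category" column onto numeric indices ordered by label frequency.

    val df = sqlContext.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    ).toDF("id", "category")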
Example 158
Source File: DCTExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.DCT
import org.apache.spark.mllib.linalg.Vectors
// $example off$
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

object DCTExample {
  def main(args: Array[String]): Unit = {
    
    val conf = new SparkConf().setAppName("DCTExample").setMaster("local[4]")
    val sc = new SparkContext(conf)
  
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    
   

    // $example on$
    val data = Seq(
      Vectors.dense(0.0, 1.0, -2.0, 3.0),
      Vectors.dense(-1.0, 2.0, 4.0, -7.0),
      Vectors.dense(14.0, -2.0, -5.0, 1.0))

    val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")
    // Discrete cosine transform (DCT)
    val dct = new DCT()
      .setInputCol("features")
      .setOutputCol("featuresDCT")
      .setInverse(false)
    // transform() converts one DataFrame into another DataFrame
    val dctDf = dct.transform(df)
    
    dctDf.select("featuresDCT").show(3)
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 159
Source File: RandomForestRegressorExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
// $example off$
import org.apache.spark.sql.Row
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

    predictions.select("prediction", "label", "features").show(5)

    // Select (prediction, true label) and compute test error.
    val evaluator = new RegressionEvaluator()
      .setLabelCol("label")
      // name of the column that stores predictions; defaults to "prediction"
      .setPredictionCol("prediction")
      // RMSE (root mean squared error) indicates how dispersed the samples are
      .setMetricName("rmse")
    val rmse = evaluator.evaluate(predictions)
    //Root Mean Squared Error (RMSE) on test data = 0.09854713827168428
    println("Root Mean Squared Error (RMSE) on test data = " + rmse)

    val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel]
    println("Learned regression forest model:\n" + rfModel.toDebugString)
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 160
Source File: StopWordsRemoverExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.StopWordsRemover
// $example off$
import org.apache.spark.sql.Row
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

    // transform() converts one DataFrame into another DataFrame
    remover.transform(dataSet).show()
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 161
Source File: TfIdfExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
// $example off$
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}

    rescaledData.show()
    rescaledData.select("features", "label").take(3).foreach(println)
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 162
Source File: VectorIndexerExample.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.VectorIndexer
// $example off$
import org.apache.spark.mllib.linalg.Vectors
// $example off$
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}
import org.apache.spark.mllib.util._

    indexedData.show()
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
Example 163
Source File: Tokenizer.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}


  def getPattern: String = $(pattern)

  setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+")

  override protected def createTransformFunc: String => Seq[String] = { str =>
    val re = $(pattern).r
    val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
    val minLength = $(minTokenLength)
    tokens.filter(_.length >= minLength)
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be string type but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, true)

  override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
} 
Example 164
Source File: MiscFunctionsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}

class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("md5") {
    checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932")
    checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      "6ac1e56bc78f031059be7be854522c4c")
    checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType)
  }

  test("sha1") {
    checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
    checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      "5d211bad8f4ee70e16c7d343a838fc344a1ed961")
    checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
    checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
    checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType)
  }

  test("sha2") {
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
    // unsupported bit length
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
  }

  test("crc32") {
    checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L)
    checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      2180413220L)
    checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
  }
} 
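The same hash expressions exercised above are also available as DataFrame functions, which is how they are usually reached outside the Catalyst test suite. A minimal sketch; the input data, the "payload" column name, and the in-scope sqlContext are illustrative assumptions.

import org.apache.spark.sql.functions.{col, crc32, md5, sha1, sha2}

// "payload" is an illustrative column name; any StringType or BinaryType column works.
val logs = sqlContext.createDataFrame(Seq(Tuple1("ABC"), Tuple1("spark"))).toDF("payload")
val hashed = logs.select(
  md5(col("payload")),        // hex-encoded MD5
  sha1(col("payload")),       // hex-encoded SHA-1
  sha2(col("payload"), 256),  // supported bit lengths: 224, 256, 384, 512 (0 means 256)
  crc32(col("payload")))      // CRC32 checksum as a bigint
hashed.show(false)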
Example 165
Source File: ColumnPruningSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types.StringType

class ColumnPruningSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Column pruning", FixedPoint(100),
      ColumnPruning) :: Nil
  }

  test("Column pruning for Generate when Generate.join = false") {
    val input = LocalRelation('a.int, 'b.array(StringType))

    val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze
    val optimized = Optimize.execute(query)

    val correctAnswer =
      Generate(Explode('b), false, false, None, 's.string :: Nil,
        Project('b.attr :: Nil, input)).analyze

    comparePlans(optimized, correctAnswer)
  }
  // Column pruning for Generate when Generate.join = true
  test("Column pruning for Generate when Generate.join = true") {
    val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))

    val query =
      Project(Seq('a, 's),
        Generate(Explode('c), true, false, None, 's.string :: Nil,
          input)).analyze
    val optimized = Optimize.execute(query)

    val correctAnswer =
      Project(Seq('a, 's),
        Generate(Explode('c), true, false, None, 's.string :: Nil,
          Project(Seq('a, 'c),
            input))).analyze

    comparePlans(optimized, correctAnswer)
  }
  // Turn Generate.join to false if possible
  test("Turn Generate.join to false if possible") {
    val input = LocalRelation('b.array(StringType))

    val query =
      Project(('s + 1).as("s+1") :: Nil,
        Generate(Explode('b), true, false, None, 's.string :: Nil,
          input)).analyze
    val optimized = Optimize.execute(query)

    val correctAnswer =
      Project(('s + 1).as("s+1") :: Nil,
        Generate(Explode('b), false, false, None, 's.string :: Nil,
          input)).analyze

    comparePlans(optimized, correctAnswer)
  }

  // todo: add more tests for column pruning
} 
Example 166
Source File: SparkSQLParser.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import scala.util.parsing.combinator.RegexParsers

import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{DescribeFunction, LogicalPlan, ShowFunctions}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.types.StringType



private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {

  // A parser for the key-value part of the "SET [key = [value ]]" syntax
  private object SetCommandParser extends RegexParsers {
    private val key: Parser[String] = "(?m)[^=]+".r

    private val value: Parser[String] = "(?m).*$".r

    private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)())

    private val pair: Parser[LogicalPlan] =
      (key ~ ("=".r ~> value).?).? ^^ {
        case None => SetCommand(None)
        case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)))
      }

    def apply(input: String): LogicalPlan = parseAll(pair, input) match {
      case Success(plan, _) => plan
      case x => sys.error(x.toString)
    }
  }

  protected val AS = Keyword("AS")
  protected val CACHE = Keyword("CACHE")
  protected val CLEAR = Keyword("CLEAR")
  protected val DESCRIBE = Keyword("DESCRIBE")
  protected val EXTENDED = Keyword("EXTENDED")
  protected val FUNCTION = Keyword("FUNCTION")
  protected val FUNCTIONS = Keyword("FUNCTIONS")
  protected val IN = Keyword("IN")
  protected val LAZY = Keyword("LAZY")
  protected val SET = Keyword("SET")
  protected val SHOW = Keyword("SHOW")
  protected val TABLE = Keyword("TABLE")
  protected val TABLES = Keyword("TABLES")
  protected val UNCACHE = Keyword("UNCACHE")

  override protected lazy val start: Parser[LogicalPlan] =
    cache | uncache | set | show | desc | others

  private lazy val cache: Parser[LogicalPlan] =
    CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
      case isLazy ~ tableName ~ plan =>
        CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
    }

  private lazy val uncache: Parser[LogicalPlan] =
    ( UNCACHE ~ TABLE ~> ident ^^ {
        case tableName => UncacheTableCommand(tableName)
      }
    | CLEAR ~ CACHE ^^^ ClearCacheCommand
    )

  private lazy val set: Parser[LogicalPlan] =
    SET ~> restInput ^^ {
      case input => SetCommandParser(input)
    }

  // It can be the following patterns:
  // SHOW FUNCTIONS;
  // SHOW FUNCTIONS mydb.func1;
  // SHOW FUNCTIONS func1;
  // SHOW FUNCTIONS `mydb.a`.`func1.aa`;
  private lazy val show: Parser[LogicalPlan] =
    ( SHOW ~> TABLES ~ (IN ~> ident).? ^^ {
        case _ ~ dbName => ShowTablesCommand(dbName)
      }
    | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ {
        case Some(f) => ShowFunctions(f._1, Some(f._2))
        case None => ShowFunctions(None, None)
      }
    )

  private lazy val desc: Parser[LogicalPlan] =
    DESCRIBE ~ FUNCTION ~> EXTENDED.? ~ (ident | stringLit) ^^ {
      case isExtended ~ functionName => DescribeFunction(functionName, isExtended.isDefined)
    }

  private lazy val others: Parser[LogicalPlan] =
    wholeInput ^^ {
      case input => fallback(input)
    }

} 
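A rough usage sketch for the wrapper above. The parse call assumes AbstractSparkSQLParser exposes a parse(String) entry point, and the fallback here is a dummy that simply rejects anything this parser does not recognise:

val parser = new SparkSQLParser(other => sys.error(s"no fallback parser configured for: $other"))

val showPlan = parser.parse("SHOW TABLES IN mydb")            // expected: ShowTablesCommand(Some("mydb"))
val setPlan  = parser.parse("SET spark.sql.shuffle.partitions=10")  // handled by SetCommandParser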
Example 167
Source File: NullableColumnAccessorSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.columnar

import java.nio.ByteBuffer

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
// Test accessor for nullable columns
class TestNullableColumnAccessor[JvmType](
    buffer: ByteBuffer,
    columnType: ColumnType[JvmType])
  extends BasicColumnAccessor(buffer, columnType)
  with NullableColumnAccessor
// Test accessor for nullable columns
object TestNullableColumnAccessor {
  def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType])
    : TestNullableColumnAccessor[JvmType] = {
    // Skips the column type ID
    buffer.getInt()
    new TestNullableColumnAccessor(buffer, columnType)
  }
}
// Nullable column accessor test suite
class NullableColumnAccessorSuite extends SparkFunSuite {
  import ColumnarTestUtils._

  Seq(
    BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
    STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
    .foreach {
    testNullableColumnAccessor(_)
  }
  // Tests a nullable column accessor for the given column type
  def testNullableColumnAccessor[JvmType](
      columnType: ColumnType[JvmType]): Unit = {
    // stripSuffix drops the trailing "$" from the class name
    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
    val nullRow = makeNullRow(1)
    // Empty column
    test(s"Nullable $typeName column accessor: empty column") {
      val builder = TestNullableColumnBuilder(columnType)
      val accessor = TestNullableColumnAccessor(builder.build(), columnType)
      assert(!accessor.hasNext)
    }
    // Access null values
    test(s"Nullable $typeName column accessor: access null values") {
      val builder = TestNullableColumnBuilder(columnType)
      val randomRow = makeRandomRow(columnType)

      (0 until 4).foreach { _ =>
        builder.appendFrom(randomRow, 0)
        builder.appendFrom(nullRow, 0)
      }

      val accessor = TestNullableColumnAccessor(builder.build(), columnType)
      val row = new GenericMutableRow(1)

      (0 until 4).foreach { _ =>
        assert(accessor.hasNext)
        accessor.extractTo(row, 0)
        assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType))

        assert(accessor.hasNext)
        accessor.extractTo(row, 0)
        assert(row.isNullAt(0))
      }

      assert(!accessor.hasNext)
    }
  }
} 
Example 168
Source File: ListTablesSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
// Test suite for listing tables
class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")

  before {
    df.registerTempTable("ListTablesSuiteTable")
  }

  after {
    ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
  }

  test("get all tables") {//获得所有的表
      
      
    ctx.tables("DB").show()
    
    checkAnswer(
      //使用数据库,查找表名
      ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    checkAnswer(
      // Query the table name via a SQL command
      sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
    assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("query the returned DataFrame of tables") {//查询返回的数据集的表名
  //StructType代表一张表,StructField代表一个字段
    val expectedSchema = StructType(
      StructField("tableName", StringType, false) ::
      StructField("isTemporary", BooleanType, false) :: Nil)

    Seq(ctx.tables(), sql("SHOW TABLes")).foreach {
      case tableDF =>
        assert(expectedSchema === tableDF.schema)

        tableDF.registerTempTable("tables")
        checkAnswer(
          sql(
            "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
          Row(true, "ListTablesSuiteTable")
        )
        checkAnswer(
          ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
          Row("tables", true))
        ctx.dropTempTable("tables")
    }
  }
} 
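Outside the test harness, the same catalog listing can be queried directly. A minimal sketch, assuming an existing SQLContext named sqlContext:

// Returns a DataFrame with the (tableName, isTemporary) schema asserted above.
sqlContext.tables()
  .filter("isTemporary = true")
  .show()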
Example 169
Source File: DDLSourceLoadSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// please note that the META-INF/services had to be modified for the test directory for this to work
// DDL data source load suite
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {//相同的名字的数据源
    intercept[RuntimeException] {
      caseInsensitiveContext.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {//从格式化别名加载数据源
    caseInsensitiveContext.read.format("gathering quorum").load().schema ==
    //StructType代表一张表,StructField代表一个字段
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("specify full classname with duplicate formats") {//重复的格式指定完整的类名
    caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("should fail to load ORC without HiveContext") {//
    intercept[ClassNotFoundException] {
      caseInsensitiveContext.read.format("orc").load()
    }
  }
}

// Fake data source one
class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}
// Fake data source two
class FakeSourceTwo extends RelationProvider  with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}
// Fake data source three
class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
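The short-name lookups above work because Spark discovers DataSourceRegister implementations through Java's ServiceLoader. A sketch of what the modified test resource mentioned in the comment would contain (the file path and the list of entries are inferred from the classes in this example, not copied from the project):

# src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
org.apache.spark.sql.sources.FakeSourceOne
org.apache.spark.sql.sources.FakeSourceTwo
org.apache.spark.sql.sources.FakeSourceThree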
Example 170
Source File: LanguageAwareAnalyzer.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import org.apache.lucene.analysis.util.StopwordAnalyzerBase
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.HasOutputCol
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}


  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def copy(extra: ParamMap): Transformer = {
    defaultCopy(extra)
  }

  def this() = this(Identifiable.randomUID("languageAnalyzer"))

  override def transform(dataset: Dataset[_]): DataFrame = {
    dataset.withColumn($(outputCol), stemmTextUDF(dataset.col($(inputColLang)), dataset.col($(inputColText)))).toDF
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    if ($(inputColText) equals $(outputCol)) {
      val schemaWithoutInput = new StructType(schema.fields.filterNot(_.name equals $(inputColText)))
      SchemaUtils.appendColumn(schemaWithoutInput, $(outputCol), ArrayType(StringType, true))
    } else {
      SchemaUtils.appendColumn(schema, $(outputCol), ArrayType(StringType, true))
    }
  }

}

object LanguageAwareAnalyzer extends DefaultParamsReadable[LanguageAwareAnalyzer] {
  override def load(path: String): LanguageAwareAnalyzer = super.load(path)
} 
Example 171
Source File: LanguageDetectorTransformer.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import com.google.common.base.Optional
import com.optimaize.langdetect.LanguageDetector
import com.optimaize.langdetect.i18n.LdLocale
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StringType, StructType}

import scala.collection.Map


  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("languageDetector"))

  override def transform(dataset: Dataset[_]): DataFrame = {
    dataset.withColumn($(outputCol), languageDetection(dataset.col($(inputCol))))
  }

  override def copy(extra: ParamMap): Transformer = {
    defaultCopy(extra)
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.appendColumn(schema, $(outputCol), StringType)
  }

  @transient object languageDetectorWrapped extends Serializable {
    val languageDetector: LanguageDetector =
      LanguageDetectorUtils.buildLanguageDetector(
        LanguageDetectorUtils.readListLangsBuiltIn(),
        $(minimalConfidence),
        $(languagePriors).toMap)
  }

} 
Example 172
Source File: NGramExtractor.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair, ParamValidators, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}


  def setOutputCol(value: String): this.type = set(outputCol, value)

  setDefault(new ParamPair[Int](upperN, 2), new ParamPair[Int](lowerN, 1))

  override def transform(dataset: Dataset[_]): DataFrame = {
    val lowerBound = $(lowerN)
    val upperBound = $(upperN)
    val nGramUDF = udf[Seq[String], Seq[String]](NGramUtils.nGramFun(_,lowerBound,upperBound))
    dataset.withColumn($(outputCol), nGramUDF(dataset.col($(inputCol))))
  }


  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    if ($(inputCol) != $(outputCol)) {
      schema.add($(outputCol), new ArrayType(StringType, true))
    } else {
      schema
    }
  }
}
object NGramExtractor extends DefaultParamsReadable[NGramExtractor] {
  override def load(path: String): NGramExtractor = super.load(path)
} 
Example 173
Source File: RegexpReplaceTransformer.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType}


  def setInputCol(value: String): this.type = set(inputCol, value)

  def this() = this(Identifiable.randomUID("RegexpReplaceTransformer"))

  override def transform(dataset: Dataset[_]): DataFrame = {
    dataset.withColumn($(outputCol), regexp_replace(dataset.col($(inputCol)), $(regexpPattern), $(regexpReplacement)))
  }
  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    if ($(inputCol) equals $(outputCol)) {
      val schemaWithoutInput = new StructType(schema.fields.filterNot(_.name equals $(inputCol)))
      SchemaUtils.appendColumn(schemaWithoutInput, $(outputCol), StringType)
    } else {
      SchemaUtils.appendColumn(schema, $(outputCol), StringType)
    }
  }

}

object RegexpReplaceTransformer extends DefaultParamsReadable[RegexpReplaceTransformer] {
  override def load(path: String): RegexpReplaceTransformer = super.load(path)
} 
Example 174
Source File: URLElimminator.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import org.apache.lucene.analysis.standard.UAX29URLEmailTokenizer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StringType, StructType}


  def setInputCol(value: String): this.type = set(inputCol, value)

  def this() = this(Identifiable.randomUID("URLEliminator"))

  override def transform(dataset: Dataset[_]): DataFrame = {
    dataset.withColumn($(outputCol), filterTextUDF(dataset.col($(inputCol))))
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    if ($(inputCol) != $(outputCol)) {
      schema.add($(outputCol), StringType)
    } else {
      schema
    }
  }
}

object URLElimminator extends DefaultParamsReadable[URLElimminator] {
  override def load(path: String): URLElimminator = super.load(path)
} 
Example 175
Source File: NameAssigner.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCols
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.{DataFrame, Dataset, functions}
import org.apache.spark.sql.types.{Metadata, StringType, StructField, StructType}


class NameAssigner(override val uid: String) extends Transformer with HasInputCols{

  def setInputCols(column: String*) : this.type = set(inputCols, column.toArray)

  def this() = this(Identifiable.randomUID("NameAssigner"))

  override def transform(dataset: Dataset[_]): DataFrame = {
    $(inputCols)

    $(inputCols).foldLeft(dataset.toDF)((data, column) => {
      val metadata: Metadata = dataset.schema(column).metadata
      val attributes = AttributeGroup.fromStructField(
        StructField(column, new VectorUDT, nullable = false, metadata = metadata))

      val map = attributes.attributes
        .map(arr => arr.filter(_.name.isDefined).map(a => a.index.get -> a.name.get).toMap)
        .getOrElse(Map())

      val func = functions.udf[String, Number](x => if(x == null) {
        null
      } else {
        val i = x.intValue()
        map.getOrElse(i, i.toString)
      })

      data.withColumn(column, func(data(column)).as(column, metadata))
    }).toDF
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    StructType(schema.map(f => if ($(inputCols).contains(f.name)) {
      StructField(f.name, StringType, f.nullable, f.metadata)
    } else {
      f
    }))
} 
Example 176
Source File: RegressionEvaluator.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl


class RegressionEvaluator(override val uid: String) extends Evaluator[RegressionEvaluator](uid) {

  val throughOrigin = new BooleanParam(this, "throughOrigin",
    "True if the regression is through the origin. For example, in " +
      "linear regression, it will be true without fitting intercept.")

  def setThroughOrigin(value: Boolean): this.type = set(throughOrigin, value)

  def getThroughOrigin: Boolean = $(throughOrigin)

  def this() = this(Identifiable.randomUID("regressionEvaluator"))


  override def transform(dataset: Dataset[_]): DataFrame = {

    try {
      val predictions: RDD[(Double, Double)] = dataset.select($(predictionCol), $(labelCol))
        .rdd.map { case Row(score: Double, label: Double) => (score, label) }

      val metrics = Try(new RegressionMetrics(predictions))


      val rows = metrics.toOption.map(m => Seq(
        "r2" -> m.r2,
        "rmse" -> m.rootMeanSquaredError,
        "explainedVariance" -> m.explainedVariance,
        "meanAbsoluteError" -> m.meanAbsoluteError,
        "meanSquaredError" -> m.meanSquaredError
      ).map(Row.fromTuple)).getOrElse(Seq())

      SparkSqlUtils.reflectionLock.synchronized(
        dataset.sqlContext.createDataFrame(
          dataset.sparkSession.sparkContext.parallelize(rows, 1), transformSchema(dataset.schema)))
    } catch {
      // Most probably evaluation dataset is empty
      case e: Exception =>
        logWarning("Failed to calculate metrics due to " + e.getMessage)
        SparkSqlUtils.reflectionLock.synchronized(
          dataset.sqlContext.createDataFrame(
            dataset.sparkSession.sparkContext.emptyRDD[Row], transformSchema(dataset.schema)))
    }
  }

  override def copy(extra: ParamMap): RegressionEvaluator = {
    copyValues(new RegressionEvaluator(), extra)
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    new StructType()
      .add("metric", StringType, nullable = false)
      .add("value", DoubleType, nullable = false)
  }
} 
Example 177
Source File: EWStatsTransformerSpec.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package odkl.analysis.spark.texts

import odkl.analysis.spark.TestEnv
import org.apache.spark.ml.odkl.texts.EWStatsTransformer
import org.apache.spark.ml.odkl.texts.EWStatsTransformer.EWStruct
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import org.scalatest.FlatSpec




class EWStatsTransformerSpec extends FlatSpec with TestEnv with org.scalatest.Matchers {

  import sqlc.implicits._

  case class dummyCase(Term: String, sig: Double, ewma: Double, ewmvar: Double)

  case class ewStruct(sig: Double, ewma: Double, ewmvar: Double) extends Serializable

  "CorrectEWFreqStatsTransformer" should "count existing and non-existing today words" in {

    val oldData = Seq(Seq("a", 0.0, 0.1, 0.01), Seq("b", 0.0, 0.2, 0.02), Seq("c", 0.0, 0.3, 0.015))

    val oldDF =
      sqlc.createDataFrame(sc.parallelize(oldData).map(f => {
        Row.fromSeq(f)
      }), new StructType().add("term", StringType)
        .add("sig", DoubleType).add("ewma", DoubleType).add("ewmvar", DoubleType))
    val rddRes = oldDF.rdd.
      map { case Row(term, sig, ewma, ewmvar) => Row(term, Row(sig, ewma, ewmvar)) }

    val schemaRes = StructType(
      StructField("term", StringType, false) ::
        StructField("ewStruct", StructType(
          StructField("sig", DoubleType, false) ::
          StructField("ewma", DoubleType, false) ::
          StructField("ewmvar", DoubleType, false) :: Nil
        ), true) :: Nil
    )
    val modernOldDF = sqlc.createDataFrame(rddRes, schemaRes)
      .withColumnRenamed("ewStruct", "old_EWStruct").withColumnRenamed("term", "old_Term")

    oldDF.collect()
    val fTransformer =
      new EWStatsTransformer()
        .setAlpha(0.7)
        .setBeta(0.055)
        .setInputFreqColName("Freq")
        .setInputTermColName("Term")
        .setOldEWStructColName("old_EWStruct")
        .setNewEWStructColName("EWStruct")
        .setOldTermColName("old_Term")
    val schema = new StructType().add("Term", StringType).add("Freq", DoubleType)

    val inDF = sqlc.createDataFrame(
      sc.parallelize(Seq(("a", 0.2), ("b", 0.1), ("d", 0.05)))
        .map(f => {
          Row.fromSeq(Seq(f._1, f._2))
        }), schema)
    val joined = inDF.join(modernOldDF, $"Term" === $"old_Term", "outer")
    val outDF = fTransformer.transform(joined)
    val ans: Array[Row] = outDF.sort("Term").collect()
    assertResult(4)(ans.size)
  }

  "CorrectEWStatsTransformer" should "count EWStats correct" in {

    val mathTransformFun: (String, Double, Double, Double) => EWStruct = EWStatsTransformer.termEWStatsComputing(_:String,_:Double,_:Double,_:Double,0.7,0.005)
    val input = ("test", 0.01, 0.006, 0.003)
    val expected = (0.0669, 0.0088, 0.0009)
    val real = mathTransformFun(input._1, input._2, input._3, input._4)
    val realRounded = (Math.round(real.sig * 10000D) / 10000D, Math.round(real.ewma * 10000D) / 10000D, Math.round(real.ewmvar * 10000D) / 10000D)
    assertResult(expected)(realRounded)
  }
} 
Example 178
Source File: NGramExtractorSpec.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package odkl.analysis.spark.texts

import odkl.analysis.spark.TestEnv
import org.apache.spark.ml.odkl.texts.NGramExtractor
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
import org.scalatest.FlatSpec


class NGramExtractorSpec extends FlatSpec with TestEnv with org.scalatest.Matchers {

  "NGramExtractor" should "extract NGrams upTo=true" in {
    val nGramExtractor =
      new NGramExtractor()
        .setUpperN(2)
        .setInputCol("textTokenized")
        .setOutputCol("nGram")

    val schema = new StructType().add("textTokenized",ArrayType(StringType,true))
    val inDF = sqlc.createDataFrame(
      sc.parallelize(Seq(Seq[String]("ab","bc","cd"),Seq[String]("a","b")))
        .map(f => {Row(f)}), schema)

    val outDF = nGramExtractor.transform(inDF)

    val outArrays = outDF.collect().map(_.getAs[Seq[String]]("nGram")).toSeq

    val correctArrays = Seq(Seq("ab","bc","cd","ab bc","bc cd"),Seq("a","b", "a b"))
    assertResult(correctArrays)(outArrays)
  }

  "NGramExtractor" should "extract NGrams upTo=false" in {
    val nGramExtractor =
      new NGramExtractor()
        .setUpperN(3)
        .setLowerN(3)
        .setInputCol("textTokenized")
        .setOutputCol("nGram")

    val schema = new StructType().add("textTokenized",ArrayType(StringType,true))
    val inDF = sqlc.createDataFrame(
      sc.parallelize(Seq(Seq[String]("a","b","c","d")).map(f => {Row(f)})),
      schema)

    val outDF = nGramExtractor.transform(inDF)

    val outArrays = outDF.collect().map(_.getAs[Seq[String]]("nGram")).toSeq

    val correctArrays = Seq(Seq("a b c", "b c d"))
    assertResult(correctArrays)(outArrays)
  }
  "NGramExtractor" should "extract NGrams with the same col" in {
    val nGramExtractor =
      new NGramExtractor()
        .setUpperN(3)
        .setLowerN(3)
        .setInputCol("textTokenized")
        .setOutputCol("textTokenized")

    val schema = new StructType().add("textTokenized",ArrayType(StringType,true))
    val inDF = sqlc.createDataFrame(
      sc.parallelize(Seq(Seq[String]("a","b","c","d")).map(f => {Row(f)})),
      schema)

    val outDF = nGramExtractor.transform(inDF)

    val outArrays = outDF.collect().map(_.getAs[Seq[String]]("textTokenized")).toSeq

    val correctArrays = Seq(Seq("a b c", "b c d"))
    assertResult(correctArrays)(outArrays)
  }

} 
Example 179
Source File: HashBasedDeduplicatorSpec.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package odkl.analysis.spark.texts

import odkl.analysis.spark.TestEnv
import org.apache.spark.ml.odkl.texts.HashBasedDeduplicator
import org.apache.spark.ml.linalg.{VectorUDT, Vectors}
import org.apache.spark.ml.odkl.MatrixUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{LongType, StringType, StructType}
import org.scalatest.FlatSpec

class HashBasedDeduplicatorSpec extends FlatSpec with TestEnv with org.scalatest.Matchers {
  "cotrect HashBasedDeduplicator " should " remove similar vectors based on hash " in {

    val vectorsSize = 10000

    val vector1 = (Vectors.sparse(vectorsSize, Array(5, 6, 7), Array(1.0, 1.0, 1.0)), 1L, "vector1")
    val vector2 = (Vectors.sparse(vectorsSize, Array(5, 6, 7), Array(1.0, 1.0, 0.0)), 1L, "vector2")
    val vector3 = (Vectors.sparse(vectorsSize, Array(5, 6, 7), Array(1.0, 0.0, 1.0)), 2L, "vector3") // pretty similar, but in the 2nd bucket
    val vector4 = (Vectors.sparse(vectorsSize, Array(1, 2), Array(1.0, 1.0)), 1L, "vector4") // completely different, but in the 1st bucket

    val schema = new StructType()
      .add("vector", MatrixUtils.vectorUDT)
      .add("hash", LongType)
      .add("alias", StringType)

    val dataFrame = sqlc.createDataFrame(sc.parallelize(Seq(vector1, vector2, vector3, vector4).map(Row.fromTuple(_))), schema)
    val deduplicator = new HashBasedDeduplicator()
      .setInputColHash("hash")
      .setInputColVector("vector")
      .setSimilarityTreshold(0.80)

   val answer = deduplicator.transform(dataFrame)
        .collect().map(row => (row.getLong(1), row.getString(2)))

    assert(answer.exists(_._2 == "vector1")) //should stay
    assert(!answer.exists(_._2 == "vector2")) //should be removed
    assert(answer.exists(_._2 == "vector3")) //should stay cause in other bucket (FalseNegative)
    assert(answer.exists(_._2 == "vector4")) //should stay cause different (FalsePositive)
  }
} 
Example 180
Source File: FreqStatsTransformerSpec.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package odkl.analysis.spark.texts

import odkl.analysis.spark.TestEnv
import org.apache.spark.ml.odkl.texts.FreqStatsTransformer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructType}
import org.scalatest.FlatSpec


class FreqStatsTransformerSpec extends FlatSpec with TestEnv with org.scalatest.Matchers {

  "FreqStatsTransformer" should "count freq" in {
    val fTransformer =  new FreqStatsTransformer()
      .setInputDataCol("data")
      .setOutputColFreq("Freq")
      .setOutputColTerm("Term")

    val schema = new StructType().add("data",ArrayType(StringType,true))
    val inDF = sqlc.createDataFrame(
      sc.parallelize(Seq(Seq[String]("a","b","c"),Seq[String]("a","b","a")))
        .map(f => {Row(f)}), schema)

    val correctAns = Array[(String,Double)](("a",2D/5D),("b",2D/5D),("c",1D/5D))
    val realAns = fTransformer.transform(inDF).sort("Term").collect().map(f =>{(f.getAs[String]("Term"),f.getAs[Double]("Freq"))})
    assertResult(correctAns)(realAns)

  }
  "FreqStatsTransformer" should "filter freq by uni and bi treshold" in {
    val fTransformer =  new FreqStatsTransformer()
      .setInputDataCol("data")
      .setOutputColFreq("Freq")
      .setOutputColTerm("Term")
      .setTresholdArr(Array[Double](1.5D/8D,1.1D/8D))

    val schema = new StructType().add("data",ArrayType(StringType,true))
    val inDF = sqlc.createDataFrame(
      sc.parallelize(Seq(Seq[String]("a","b","c","c a", "c a"),Seq[String]("a","b","a", "c a", "a b")))
        .map(f => {Row(f)}), schema)

    val correctAns = Array[(String,Double)](("a",2D/8D),("b",2D/8D),("c a",2D/8D))
    val realAnsDF = fTransformer.transform(inDF).sort("Term")
      val realAns = realAnsDF.collect().map(f =>{(f.getAs[String]("Term"),f.getAs[Double]("Freq"))})
    assertResult(correctAns)(realAns)

  }

  "FreqStatsTransformer" should "extract max timestamp by term" in {
    val fTransformer =  new FreqStatsTransformer()
      .setInputDataCol("data")
      .setOutputColFreq("Freq")
      .setOutputColTerm("Term")
        .setWithTimestamp(true)
        .setTimestampColumnName("timestamp")
      .setTresholdArr(Array[Double](1D/8D,1.1D/8D))

    val schema =
      new StructType().add("data",ArrayType(StringType,true)).add("timestamp",LongType)
    val inDF = sqlc.createDataFrame(
      sc.parallelize(Seq(Seq(Seq[String]("a","c","c a", "c a"),100L),Seq(Seq[String]("c a", "a b"),150L),Seq(Seq[String]("b"),200L)))
        .map(f => {Row.fromSeq(f)}), schema)

    inDF.collect()
    val correctAns = Array[(String,Double,Long)](("a",1D/6D,100L),("a b",1D/6D, 150L),("b",1D/6D,200L),
      ("c",1D/6D, 100L),("c a",2D/6D, 150L))
    val realAns = fTransformer.transform(inDF).sort("Term").collect().map(f =>{(f.getAs[String]("Term"),f.getAs[Double]("Freq"),f.getAs[Long]("timestamp"))})
    assertResult(correctAns)(realAns)
    assertResult(correctAns(1))(realAns(1))

  }
} 
Example 181
Source File: CustomSchemaTest.scala    From spark-sftp   with Apache License 2.0 5 votes vote down vote up
package com.springml.spark.sftp

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, _}
import org.scalatest.{BeforeAndAfterEach, FunSuite}


class CustomSchemaTest extends FunSuite with BeforeAndAfterEach {
  var ss: SparkSession = _

  val csvTypesMap = Map("ProposalId" -> IntegerType,
    "OpportunityId" -> StringType,
    "Clicks" -> LongType,
    "Impressions" -> LongType
  )

  val jsonTypesMap = Map("name" -> StringType,
    "age" -> IntegerType
  )

  override def beforeEach() {
    ss = SparkSession.builder().master("local").appName("Custom Schema Test").getOrCreate()
  }

  private def validateTypes(field : StructField, typeMap : Map[String, DataType]) = {
    val expectedType = typeMap(field.name)
    assert(expectedType == field.dataType)
  }

  private def columnArray(typeMap : Map[String, DataType]) : Array[StructField] = {
    // Copying into a zero-length array would silently drop every column, so build the array directly.
    typeMap.map(x => StructField(x._1, x._2, true)).toArray
  }

  test ("Read CSV with custom schema") {
    val columnStruct = columnArray(csvTypesMap)
    val expectedSchema = StructType(columnStruct)

    val fileLocation = getClass.getResource("/sample.csv").getPath
    val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ",", "\"", "\\", "false", null, expectedSchema, ss.sqlContext)
    val rdd = dsr.buildScan()

    assert(dsr.schema.fields.length == columnStruct.length)
    dsr.schema.fields.foreach(s => validateTypes(s, csvTypesMap))
  }

  test ("Read Json with custom schema") {
    val columnStruct = columnArray(jsonTypesMap)
    val expectedSchema = StructType(columnStruct)

    val fileLocation = getClass.getResource("/people.json").getPath
    val dsr = DatasetRelation(fileLocation, "json", "false", "true", ",", "\"", "\\", "false", null, expectedSchema, ss.sqlContext)
    val rdd = dsr.buildScan()

    assert(dsr.schema.fields.length == columnStruct.length)
    dsr.schema.fields.foreach(s => validateTypes(s, jsonTypesMap))
  }

} 
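The same kind of user-supplied schema can also be applied to a plain Spark read, independent of the spark-sftp DatasetRelation. A small sketch using only stock Spark APIs and the session created in the test:

import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}

val csvSchema = StructType(Seq(
  StructField("ProposalId", IntegerType, nullable = true),
  StructField("OpportunityId", StringType, nullable = true),
  StructField("Clicks", LongType, nullable = true),
  StructField("Impressions", LongType, nullable = true)))

val df = ss.read
  .schema(csvSchema)            // skip schema inference and enforce these types
  .option("header", "true")
  .csv(getClass.getResource("/sample.csv").getPath)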
Example 182
Source File: NGram.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}


  @Since("1.5.0")
  def getN: Int = $(n)

  setDefault(n -> 2)

  override protected def createTransformFunc: Seq[String] => Seq[String] = {
    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType.sameType(ArrayType(StringType)),
      s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, false)
}

@Since("1.6.0")
object NGram extends DefaultParamsReadable[NGram] {

  @Since("1.6.0")
  override def load(path: String): NGram = super.load(path)
} 
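A short usage sketch for the transformer above; the local SparkSession and the sample data are assumptions for illustration only:

import org.apache.spark.ml.feature.NGram
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("ngram-sketch").getOrCreate()
import spark.implicits._

val df = Seq(Tuple1(Seq("the", "quick", "brown", "fox"))).toDF("words")
val bigrams = new NGram().setN(2).setInputCol("words").setOutputCol("bigrams")
bigrams.transform(df).select("bigrams").show(false)  // [the quick, quick brown, brown fox]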
Example 183
Source File: KafkaWriter.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.types.{BinaryType, StringType}
import org.apache.spark.util.Utils


private[kafka010] object KafkaWriter extends Logging {
  val TOPIC_ATTRIBUTE_NAME: String = "topic"
  val KEY_ATTRIBUTE_NAME: String = "key"
  val VALUE_ATTRIBUTE_NAME: String = "value"

  override def toString: String = "KafkaWriter"

  def validateQuery(
      schema: Seq[Attribute],
      kafkaParameters: ju.Map[String, Object],
      topic: Option[String] = None): Unit = {
    schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
      if (topic.isEmpty) {
        throw new AnalysisException(s"topic option required when no " +
          s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
          s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
      } else {
        Literal(topic.get, StringType)
      }
    ).dataType match {
      case StringType => // good
      case _ =>
        throw new AnalysisException(s"Topic type must be a String")
    }
    schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse(
      Literal(null, StringType)
    ).dataType match {
      case StringType | BinaryType => // good
      case _ =>
        throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " +
          s"must be a String or BinaryType")
    }
    schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse(
      throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found")
    ).dataType match {
      case StringType | BinaryType => // good
      case _ =>
        throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " +
          s"must be a String or BinaryType")
    }
  }

  def write(
      sparkSession: SparkSession,
      queryExecution: QueryExecution,
      kafkaParameters: ju.Map[String, Object],
      topic: Option[String] = None): Unit = {
    val schema = queryExecution.analyzed.output
    validateQuery(schema, kafkaParameters, topic)
    queryExecution.toRdd.foreachPartition { iter =>
      val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
      Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
        finallyBlock = writeTask.close())
    }
  }
} 
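The attribute checks above mirror what the public Kafka sink expects from a batch write. A hedged sketch of the caller-side API; df, its columns, the broker address, and the topic name are placeholders:

// df is any DataFrame whose key/value columns can be cast to STRING or BINARY.
df.selectExpr("CAST(id AS STRING) AS key", "CAST(payload AS STRING) AS value")
  .write
  .format("kafka")
  .option("kafka.bootstrap.servers", "broker-1:9092")
  .option("topic", "events")   // or provide a 'topic' column instead of this option
  .save()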
Example 184
Source File: SparkExecuteStatementOperationSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType}

class SparkExecuteStatementOperationSuite extends SparkFunSuite {
  test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") {
    val field1 = StructField("NULL", NullType)
    val field2 = StructField("(IF(true, NULL, NULL))", NullType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
  }

  test("SPARK-20146 Comment should be preserved") {
    val field1 = StructField("column1", StringType).withComment("comment 1")
    val field2 = StructField("column2", IntegerType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE)
    assert(columns.get(0).getComment() == "comment 1")
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE)
    assert(columns.get(1).getComment() == "")
  }
} 
Example 185
Source File: StatsEstimationTestBase.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}


trait StatsEstimationTestBase extends SparkFunSuite {

  var originalValue: Boolean = false

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Enable stats estimation based on CBO.
    originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
  }

  override def afterAll(): Unit = {
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
    super.afterAll()
  }

  def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
    // For UTF8String: base + offset + numBytes
    case StringType => colStat.avgLen + 8 + 4
    case _ => colStat.avgLen
  }

  def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()
}

case class StatsTestPlan(
    outputList: Seq[Attribute],
    rowCount: BigInt,
    attributeStats: AttributeMap[ColumnStat],
    size: Option[BigInt] = None) extends LeafNode {
  override def output: Seq[Attribute] = outputList
  override def computeStats(): Statistics = Statistics(
    // If sizeInBytes is useless in testing, we just use a fake value
    sizeInBytes = size.getOrElse(Int.MaxValue),
    rowCount = Some(rowCount),
    attributeStats = attributeStats)
} 
Example 186
Source File: ScalaUDFSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase(Locale.ROOT),
      StringType,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

  test("SPARK-22695: ScalaUDF should not use global variables") {
    val ctx = new CodegenContext
    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
    assert(ctx.inlinedMutableStates.isEmpty)
  }
} 
Example 187
Source File: GenerateUnsafeProjectionSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class GenerateUnsafeProjectionSuite extends SparkFunSuite {
  test("Test unsafe projection string access pattern") {
    val dataType = (new StructType).add("a", StringType)
    val exprs = BoundReference(0, dataType, nullable = true) :: Nil
    val projection = GenerateUnsafeProjection.generate(exprs)
    val result = projection.apply(InternalRow(AlwaysNull))
    assert(!result.isNullAt(0))
    assert(result.getStruct(0, 1).isNullAt(0))
  }
}

object AlwaysNull extends InternalRow {
  override def numFields: Int = 1
  override def setNullAt(i: Int): Unit = {}
  override def copy(): InternalRow = this
  override def anyNull: Boolean = true
  override def isNullAt(ordinal: Int): Boolean = true
  override def update(i: Int, value: Any): Unit = notSupported
  override def getBoolean(ordinal: Int): Boolean = notSupported
  override def getByte(ordinal: Int): Byte = notSupported
  override def getShort(ordinal: Int): Short = notSupported
  override def getInt(ordinal: Int): Int = notSupported
  override def getLong(ordinal: Int): Long = notSupported
  override def getFloat(ordinal: Int): Float = notSupported
  override def getDouble(ordinal: Int): Double = notSupported
  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
  override def getUTF8String(ordinal: Int): UTF8String = notSupported
  override def getBinary(ordinal: Int): Array[Byte] = notSupported
  override def getInterval(ordinal: Int): CalendarInterval = notSupported
  override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
  override def getArray(ordinal: Int): ArrayData = notSupported
  override def getMap(ordinal: Int): MapData = notSupported
  override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
  private def notSupported: Nothing = throw new UnsupportedOperationException
} 
Example 188
Source File: ComplexDataSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class ComplexDataSuite extends SparkFunSuite {
  def utf8(str: String): UTF8String = UTF8String.fromString(str)

  test("inequality tests for MapData") {
    // test data
    val testMap1 = Map(utf8("key1") -> 1)
    val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2)
    val testMap3 = Map(utf8("key1") -> 1)
    val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }

  test("GenericInternalRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericRow = genericRow.copy()
    assert(copiedGenericRow.getString(0) == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied internal row should not be changed externally.
    assert(copiedGenericRow.getString(0) == "a")
  }

  test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val mutableRow = new SpecificInternalRow(Seq(StringType))
    mutableRow(0) = unsafeRow.getUTF8String(0)
    val copiedMutableRow = mutableRow.copy()
    assert(copiedMutableRow.getString(0) == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied internal row should not be changed externally.
    assert(copiedMutableRow.getString(0) == "a")
  }

  test("GenericArrayData.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericArray = genericArray.copy()
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied array data should not be changed externally.
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
  }

  test("copy on nested complex type") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0))))
    val copied = arrayOfRow.copy()
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied data should not be changed externally.
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
  }
} 
Example 189
Source File: LikeSimplificationSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{BooleanType, StringType}

class LikeSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Like Simplification", Once,
        LikeSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify Like into StartsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc%") || ('a like "abc\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(StartsWith('a, "abc") || ('a like "abc\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EndsWith") {
    val originalQuery =
      testRelation
        .where('a like "%xyz")

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(EndsWith('a, "xyz"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into startsWith and EndsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc\\%def") || ('a like "abc%def"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a like "abc\\%def") ||
        (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def"))))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into Contains") {
    val originalQuery =
      testRelation
        .where(('a like "%mn%") || ('a like "%mn\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(Contains('a, "mn") || ('a like "%mn\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EqualTo") {
    val originalQuery =
      testRelation
        .where(('a like "") || ('a like "abc"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a === "") || ('a === "abc"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("null pattern") {
    val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze
    val optimized = Optimize.execute(originalQuery)
    comparePlans(optimized, testRelation.where(Literal(null, BooleanType)).analyze)
  }
} 
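The same rewrite can be observed on an end-to-end query. A minimal sketch, assuming a SparkSession named spark:

import spark.implicits._

val df = Seq("abcdef", "xyz").toDF("a")
// The optimized plan printed by explain should contain StartsWith(a, abc) instead of a LIKE predicate.
df.filter($"a".like("abc%")).explain(true)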
Example 190
Source File: RewriteDistinctAggregatesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
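
At the SQL level this rewrite is what makes several distinct aggregates in one GROUP BY feasible: the plan is expanded so each distinct group gets its own set of projected rows, giving the Aggregate over Aggregate over Expand shape that checkRewrite asserts. A brief usage sketch, assuming an existing SparkSession named spark and a registered view named sales (both illustrative):

// Two distinct groups plus a regular aggregate, mirroring the
// "multiple distinct groups with partial aggregates" test above.
val df = spark.sql(
  """SELECT a,
    |       COUNT(DISTINCT b, c) AS bc_distinct,
    |       COUNT(DISTINCT d)    AS d_distinct,
    |       SUM(e)               AS e_sum
    |FROM sales
    |GROUP BY a""".stripMargin)
df.explain(true)  // the optimized plan should contain an Expand node
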
Example 191
Source File: resources.scala    From Spark-2.3.1   with Apache License 2.0 5 votes
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
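
ListJarsCommand backs the LIST JARS statement in Spark SQL, returning one StringType "Results" row per jar known to the SparkContext. A short usage sketch, assuming a running SparkSession named spark and a jar path that exists on the driver (the path is illustrative):

spark.sparkContext.addJar("/tmp/extra-udfs.jar")   // hypothetical jar
spark.sql("LIST JARS").show(truncate = false)      // one row per jar path
// Passing jar names after LIST JARS filters the output, per the
// jars.nonEmpty branch in ListJarsCommand.run above.
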
Example 192
Source File: GroupedIteratorSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
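
One detail worth keeping in mind: GroupedIterator chunks consecutive rows with equal keys, so it assumes its input is already sorted (or at least clustered) by the grouping expressions. A small sketch of what happens otherwise, reusing the encoder pattern from the tests above and assumed to run in the same org.apache.spark.sql.execution package as the suite:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val schema = new StructType().add("i", IntegerType).add("s", StringType)
val encoder = RowEncoder(schema).resolveAndBind()

// Key 1 appears twice but is not contiguous, so it yields two separate groups.
val unsorted = Seq(Row(1, "a"), Row(2, "b"), Row(1, "c"))
val grouped = GroupedIterator(unsorted.iterator.map(encoder.toRow),
  Seq('i.int.at(0)), schema.toAttributes)
assert(grouped.length == 3)
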
Example 193
Source File: SpreadsheetRelation.scala    From mimir   with Apache License 2.0 5 votes
package mimir.exec.spark.datasource.google.spreadsheet

import mimir.exec.spark.datasource.google.spreadsheet.SparkSpreadsheetService.SparkSpreadsheetContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

case class SpreadsheetRelation protected[spark] (
                                                  context:SparkSpreadsheetContext,
                                                  spreadsheetName: String,
                                                  worksheetName: String,
                                                  userSchema: Option[StructType] = None)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with InsertableRelation {

  import mimir.exec.spark.datasource.google.spreadsheet.SparkSpreadsheetService._

  private val fieldMap = scala.collection.mutable.Map[String, String]()
  override def schema: StructType = userSchema.getOrElse(inferSchema())

  private lazy val aWorksheet: SparkWorksheet =
    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(aWorksheet) => aWorksheet
      case Left(e) => throw e
    }

  private lazy val rows: Seq[Map[String, String]] = aWorksheet.rows

  private[spreadsheet] def findWorksheet(spreadsheetName: String, worksheetName: String)(implicit ctx: SparkSpreadsheetContext): Either[Throwable, SparkWorksheet] =
    for {
      sheet <- findSpreadsheet(spreadsheetName).toRight(new RuntimeException(s"no such spreadsheet: $spreadsheetName")).right
      worksheet <- sheet.findWorksheet(worksheetName).toRight(new RuntimeException(s"no such worksheet: $worksheetName")).right
    } yield worksheet

  override def buildScan(): RDD[Row] = {
    val aSchema = schema
    val schemaMap = fieldMap.toMap
    sqlContext.sparkContext.makeRDD(rows).mapPartitions { iter =>
      iter.map { m =>
        var index = 0
        val rowArray = new Array[Any](aSchema.fields.length)
        while(index < aSchema.fields.length) {
          val field = aSchema.fields(index)
          rowArray(index) = if (m.contains(field.name)) {
            TypeCast.castTo(m(field.name), field.dataType, field.nullable)
          } else if (schemaMap.contains(field.name) && m.contains(schemaMap(field.name))) {
            TypeCast.castTo(m(schemaMap(field.name)), field.dataType, field.nullable)
          } else {
            null
          }
          index += 1
        }
        Row.fromSeq(rowArray)
      }
    }
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    if(!overwrite) {
      sys.error("Spreadsheet tables only support INSERT OVERWRITE for now.")
    }

    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(w) =>
        w.updateCells(data.schema, data.collect().toList, Util.toRowData)
      case Left(e) =>
        throw e
    }
  }

  def sanitizeColumnName(name: String): String =
  {
    name
      .replaceAll("[^a-zA-Z0-9]+", "_")    // Replace sequences of non-alphanumeric characters with underscores
      .replaceAll("_+$", "")               // Strip trailing underscores
      .replaceAll("^[0-9_]+", "")          // Strip leading underscores and digits
  }

  private def inferSchema(): StructType =
    StructType(aWorksheet.headers.toList.map { fieldName => {
      val sanitizedName = sanitizeColumnName(fieldName)
      fieldMap.put(sanitizedName, fieldName)
      StructField(sanitizedName, StringType, true)
    }})

} 
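
sanitizeColumnName drives both the inferred StringType schema and the fieldMap that buildScan uses to look the original spreadsheet header back up. A few worked examples of the three regex passes, assuming rel is a constructed SpreadsheetRelation and the header names are made up:

rel.sanitizeColumnName("2019 Sales (USD)")  // "Sales_USD": non-alphanumerics -> "_", then trailing "_" and leading digits stripped
rel.sanitizeColumnName("order id")          // "order_id"
rel.sanitizeColumnName("__total__")         // "total"
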
Example 194
Source File: JsonGroupArray.scala    From mimir   with Apache License 2.0 5 votes
package mimir.exec.spark.udf

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{ DataType, StringType }
import org.apache.spark.sql.catalyst.expressions.{ 
  AttributeReference, 
  If, 
  StartsWith, 
  Literal, 
  IsNull, 
  Concat, 
  Substring
}

case class JsonGroupArray(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate {
  override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil
  override def nullable: Boolean = false
  // Return data type.
  override def dataType: DataType = StringType
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function json_group_array")
  private lazy val json_group_array = AttributeReference("json_group_array", StringType)()
  override lazy val aggBufferAttributes: Seq[AttributeReference] = json_group_array :: Nil
  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create("", StringType)
  )
  override lazy val updateExpressions: Seq[ org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    If(IsNull(child),
      Concat(Seq(json_group_array, Literal(","), Literal("null"))),
      Concat(Seq(json_group_array, Literal(","), org.apache.spark.sql.catalyst.expressions.Cast(child,StringType,None)))) 
  )
  override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = {
    Seq(
      Concat(Seq(json_group_array.left, json_group_array.right))
    )
  }
  override lazy val evaluateExpression = Concat(Seq(
    Literal("["),
    If(StartsWith(json_group_array, Literal(",")),
      Substring(json_group_array, Literal(2), Literal(Integer.MAX_VALUE)),
      json_group_array),
    Literal("]")))
} 
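
Because JsonGroupArray is a DeclarativeAggregate rather than a function registered with Spark SQL, it has to be wrapped in a Column before the DataFrame API can use it. A minimal sketch, assuming a DataFrame named df with columns k and v:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col

// Wrap the aggregate expression so it can be passed to agg().
def jsonGroupArray(c: Column): Column =
  new Column(JsonGroupArray(c.expr).toAggregateExpression())

// v is cast to StringType inside updateExpressions before being appended.
val result = df.groupBy(col("k")).agg(jsonGroupArray(col("v")).as("vs_json"))
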
Example 195
Source File: ProtobufRequestRowSerializerTests.scala    From sagemaker-spark   with Apache License 2.0 5 votes
package com.amazonaws.services.sagemaker.sparksdk.transformation.serializers

import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.mock.MockitoSugar

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, SQLDataTypes}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

import com.amazonaws.services.sagemaker.sparksdk.protobuf.ProtobufConverter

class ProtobufRequestRowSerializerTests extends FlatSpec with Matchers with MockitoSugar {

  val labelColumnName = "label"
  val featuresColumnName = "features"
  val schema = StructType(Array(StructField(labelColumnName, DoubleType), StructField(
    featuresColumnName, VectorType)))

  it should "serialize a dense vector" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    val rrs = new ProtobufRequestRowSerializer(Some(schema))
    val protobuf = ProtobufConverter.rowToProtobuf(row, featuresColumnName, Option.empty)
    val serialized = rrs.serializeRow(row)
    val protobufIterator = ProtobufConverter.recordIOByteArrayToProtobufs(serialized)
    val protobufFromRecordIO = protobufIterator.next

    assert(!protobufIterator.hasNext)
    assert(protobuf.equals(protobufFromRecordIO))
  }

  it should "serialize a sparse vector" in {
    val vec = new SparseVector(100, Seq[Int](0, 10).toArray, Seq[Double](-100.0, 100.1).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    val rrs = new ProtobufRequestRowSerializer(Some(schema))
    val protobuf = ProtobufConverter.rowToProtobuf(row, featuresColumnName, Option.empty)
    val serialized = rrs.serializeRow(row)
    val protobufIterator = ProtobufConverter.recordIOByteArrayToProtobufs(serialized)
    val protobufFromRecordIO = protobufIterator.next

    assert(!protobufIterator.hasNext)
    assert(protobuf.equals(protobufFromRecordIO))
  }

  it should "fail to set schema on invalid features name" in {
    val vec = new SparseVector(100, Seq[Int](0, 10).toArray, Seq[Double](-100.0, 100.1).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    intercept[IllegalArgumentException] {
      val rrs = new ProtobufRequestRowSerializer(Some(schema), featuresColumnName = "doesNotExist")
    }
  }


  it should "fail on invalid types" in {
    val schemaWithInvalidFeaturesType = StructType(Array(
      StructField("label", DoubleType, nullable = false),
      StructField("features", StringType, nullable = false)))
    intercept[RuntimeException] {
      new ProtobufRequestRowSerializer(Some(schemaWithInvalidFeaturesType))
    }
  }

  it should "validate correct schema" in {
    val validSchema = StructType(Array(
      StructField("features", SQLDataTypes.VectorType, nullable = false)))
    new ProtobufRequestRowSerializer(Some(validSchema))
  }
} 
Example 196
Source File: UnlabeledLibSVMRequestRowSerializerTests.scala    From sagemaker-spark   with Apache License 2.0 5 votes
package com.amazonaws.services.sagemaker.sparksdk.transformation.serializers

import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.mock.MockitoSugar

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, SQLDataTypes}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class UnlabeledLibSVMRequestRowSerializerTests extends FlatSpec with Matchers with MockitoSugar {

  val schema = StructType(Array(StructField("features", SQLDataTypes.VectorType, nullable = false)))

  "UnlabeledLibSVMRequestRowSerializer" should "serialize sparse vector" in {
    val vec = new SparseVector(100, Seq[Int](0, 10).toArray, Seq[Double](-100.0, 100.1).toArray)
    val row = new GenericRowWithSchema(values = Seq(vec).toArray, schema = schema)
    val rrs = new UnlabeledLibSVMRequestRowSerializer()
    val serialized = new String(rrs.serializeRow(row))
    assert ("0.0 1:-100.0 11:100.1\n" == serialized)
  }

  it should "serialize dense vector" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq(vec).toArray, schema = schema)
    val rrs = new UnlabeledLibSVMRequestRowSerializer()
    val serialized = new String(rrs.serializeRow(row))
    assert("0.0 1:10.0 2:-100.0 3:2.0\n" == serialized)
  }

  it should "fail on invalid features column name" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    val rrs =
      new UnlabeledLibSVMRequestRowSerializer(featuresColumnName = "mangoes are not features")
    intercept[RuntimeException] {
      rrs.serializeRow(row)
    }
  }

  it should "fail on invalid features type" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row =
      new GenericRowWithSchema(values = Seq(1.0, "FEATURESSSSSZ!1!").toArray, schema = schema)
    val rrs = new UnlabeledLibSVMRequestRowSerializer()
    intercept[RuntimeException] {
      rrs.serializeRow(row)
    }
  }


  it should "validate correct schema" in {
    val validSchema = StructType(Array(
      StructField("features", SQLDataTypes.VectorType, nullable = false)))

    val rrs = new UnlabeledLibSVMRequestRowSerializer(Some(validSchema))
  }

  it should "fail to validate incorrect schema" in {
    val invalidSchema = StructType(Array(
      StructField("features", StringType, nullable = false)))

    intercept[IllegalArgumentException] {
      new UnlabeledLibSVMRequestRowSerializer(Some(invalidSchema))
    }
  }
} 
Example 197
Source File: LibSVMRequestRowSerializerTests.scala    From sagemaker-spark   with Apache License 2.0 5 votes
package com.amazonaws.services.sagemaker.sparksdk.transformation.serializers

import org.scalatest._
import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.mock.MockitoSugar

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, SQLDataTypes}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.LibSVMResponseRowDeserializer

class LibSVMRequestRowSerializerTests extends FlatSpec with Matchers with MockitoSugar {
  val schema = new LibSVMResponseRowDeserializer(10).schema

  "LibSVMRequestRowSerializer" should "serialize sparse vector" in {

    val vec = new SparseVector(100, Seq[Int](0, 10).toArray, Seq[Double](-100.0, 100.1).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    val rrs = new LibSVMRequestRowSerializer(Some(schema))
    val serialized = new String(rrs.serializeRow(row))
    assert ("1.0 1:-100.0 11:100.1\n" == serialized)
  }

  it should "serialize dense vector" in {

    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    val rrs = new LibSVMRequestRowSerializer(Some(schema))
    val serialized = new String(rrs.serializeRow(row))
    assert("1.0 1:10.0 2:-100.0 3:2.0\n" == serialized)
  }

  it should "ignore other columns" in {
    val schemaWithExtraColumns = StructType(Array(
      StructField("name", StringType, nullable = false),
      StructField("label", DoubleType, nullable = false),
      StructField("features", SQLDataTypes.VectorType, nullable = false),
        StructField("favorite activity", StringType, nullable = false)))

    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq("Elizabeth", 1.0, vec, "Crying").toArray,
      schema = schemaWithExtraColumns)

    val rrs = new LibSVMRequestRowSerializer(Some(schemaWithExtraColumns))
    val serialized = new String(rrs.serializeRow(row))
    assert("1.0 1:10.0 2:-100.0 3:2.0\n" == serialized)
  }

  it should "fail on invalid features column name" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    intercept[RuntimeException] {
      new LibSVMRequestRowSerializer(Some(schema), featuresColumnName = "i do not exist dear sir!")
    }
  }

  it should "fail on invalid label column name" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    intercept[RuntimeException] {
      new LibSVMRequestRowSerializer(Some(schema),
        labelColumnName = "Sir! I must protest! I do not exist!")
    }
  }

  it should "fail on invalid types" in {
    val schemaWithInvalidLabelType = StructType(Array(
      StructField("label", StringType, nullable = false),
      StructField("features", SQLDataTypes.VectorType, nullable = false)))
    intercept[RuntimeException] {
      new LibSVMRequestRowSerializer(Some(schemaWithInvalidLabelType))
    }
    val schemaWithInvalidFeaturesType = StructType(Array(
      StructField("label", DoubleType, nullable = false),
      StructField("features", StringType, nullable = false)))
    intercept[RuntimeException] {
      new LibSVMRequestRowSerializer(Some(schemaWithInvalidFeaturesType))
    }
  }

  it should "validate correct schema" in {
    val validSchema = StructType(Array(
      StructField("label", DoubleType, nullable = false),
      StructField("features", SQLDataTypes.VectorType, nullable = false)))
    new LibSVMRequestRowSerializer(Some(validSchema))
  }
} 
Example 198
Source File: KinesisWriteTask.scala    From kinesis-sql   with Apache License 2.0 5 votes
package org.apache.spark.sql.kinesis

import java.nio.ByteBuffer

import com.amazonaws.services.kinesis.producer.{KinesisProducer, UserRecordResult}
import com.google.common.util.concurrent.{FutureCallback, Futures}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, UnsafeProjection}
import org.apache.spark.sql.types.{BinaryType, StringType}

private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, String],
                                        inputSchema: Seq[Attribute]) extends Logging {

  private var producer: KinesisProducer = _
  private val projection = createProjection
  private val streamName = producerConfiguration.getOrElse(
    KinesisSourceProvider.SINK_STREAM_NAME_KEY, "")

  def execute(iterator: Iterator[InternalRow]): Unit = {
    producer = CachedKinesisProducer.getOrCreate(producerConfiguration)
    while (iterator.hasNext) {
      val currentRow = iterator.next()
      val projectedRow = projection(currentRow)
      val partitionKey = projectedRow.getString(0)
      val data = projectedRow.getBinary(1)

      sendData(partitionKey, data)
    }
  }

  def sendData(partitionKey: String, data: Array[Byte]): String = {
    var sentSeqNumbers = new String

    val future = producer.addUserRecord(streamName, partitionKey, ByteBuffer.wrap(data))

    val kinesisCallBack = new FutureCallback[UserRecordResult]() {

      override def onFailure(t: Throwable): Unit = {
        logError(s"Writing to  $streamName failed due to ${t.getCause}")
      }

      override def onSuccess(result: UserRecordResult): Unit = {
        val shardId = result.getShardId
        sentSeqNumbers = result.getSequenceNumber
      }
    }
    Futures.addCallback(future, kinesisCallBack)

    producer.flushSync()
    sentSeqNumbers
  }

  def close(): Unit = {
    if (producer != null) {
      producer.flush()
      producer = null
    }
  }

  private def createProjection: UnsafeProjection = {

    val partitionKeyExpression = inputSchema
      .find(_.name == KinesisWriter.PARTITION_KEY_ATTRIBUTE_NAME).getOrElse(
      throw new IllegalStateException("Required attribute " +
        s"'${KinesisWriter.PARTITION_KEY_ATTRIBUTE_NAME}' not found"))

    partitionKeyExpression.dataType match {
      case StringType | BinaryType => // ok
      case t =>
        throw new IllegalStateException(s"${KinesisWriter.PARTITION_KEY_ATTRIBUTE_NAME} " +
          "attribute type must be a String or BinaryType")
    }

    val dataExpression = inputSchema.find(_.name == KinesisWriter.DATA_ATTRIBUTE_NAME).getOrElse(
      throw new IllegalStateException("Required attribute " +
        s"'${KinesisWriter.DATA_ATTRIBUTE_NAME}' not found")
    )

    dataExpression.dataType match {
      case StringType | BinaryType => // ok
      case t =>
        throw new IllegalStateException(s"${KinesisWriter.DATA_ATTRIBUTE_NAME} " +
          "attribute type must be a String or BinaryType")
    }

    UnsafeProjection.create(
      Seq(Cast(partitionKeyExpression, StringType), Cast(dataExpression, StringType)), inputSchema)
  }

} 
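
From the sink's point of view the input only needs the two attributes that createProjection checks: a string-or-binary partition key and a string-or-binary payload. A hedged sketch of preparing such a DataFrame; the column names "partitionKey" and "data" and the sink option names are assumptions and may differ by connector version:

// events is an assumed streaming DataFrame with userId and payload columns.
val prepared = events.selectExpr(
  "CAST(userId AS STRING) AS partitionKey",
  "CAST(payload AS BINARY) AS data")

val query = prepared.writeStream
  .format("kinesis")
  .option("streamName", "my-stream")                                  // assumed option name
  .option("endpointUrl", "https://kinesis.us-east-1.amazonaws.com")   // assumed option name
  .option("checkpointLocation", "/tmp/kinesis-sink-checkpoint")
  .start()
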
Example 199
Source File: NGram.scala    From BigDatalog   with Apache License 2.0 5 votes
package org.apache.spark.ml.feature

import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}


@Experimental
class NGram(override val uid: String)
  extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("ngram"))

  // Minimum n-gram length, must be >= 1 (default: 2, i.e. bigrams).
  val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)",
    ParamValidators.gtEq(1))

  def setN(value: Int): this.type = set(n, value)

  def getN: Int = $(n)

  setDefault(n -> 2)

  override protected def createTransformFunc: Seq[String] => Seq[String] = {
    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType.sameType(ArrayType(StringType)),
      s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, false)
}

@Since("1.6.0")
object NGram extends DefaultParamsReadable[NGram] {

  @Since("1.6.0")
  override def load(path: String): NGram = super.load(path)
} 
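
A short usage sketch of the transformer, assuming a DataFrame named tokenized that already has an ArrayType(StringType) column called "tokens":

import org.apache.spark.ml.feature.NGram

val ngram = new NGram()
  .setN(2)                  // 2 is also the default, set explicitly for clarity
  .setInputCol("tokens")    // must be ArrayType(StringType), see validateInputType
  .setOutputCol("bigrams")

val withBigrams = ngram.transform(tokenized)
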
Example 200
Source File: MiscFunctionsSuite.scala    From BigDatalog   with Apache License 2.0 5 votes
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}

class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("md5") {
    checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932")
    checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      "6ac1e56bc78f031059be7be854522c4c")
    checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType)
  }

  test("sha1") {
    checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
    checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      "5d211bad8f4ee70e16c7d343a838fc344a1ed961")
    checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
    checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
    checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType)
  }

  test("sha2") {
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
    // unsupported bit length
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
  }

  test("crc32") {
    checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L)
    checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      2180413220L)
    checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
  }
}
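
The same hash expressions are exposed through org.apache.spark.sql.functions, where string columns are accepted via an implicit cast to binary. A brief usage sketch, assuming a DataFrame named df with a StringType column called payload:

import org.apache.spark.sql.functions.{col, crc32, md5, sha1, sha2}

val hashed = df.select(
  md5(col("payload")).as("md5"),
  sha1(col("payload")).as("sha1"),
  sha2(col("payload"), 256).as("sha256"),   // SHA-256; 224, 384, and 512 are also supported
  crc32(col("payload")).as("crc32"))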