com.holdenkarau.spark.testing.DataFrameSuiteBase Scala Examples

The following examples show how to use com.holdenkarau.spark.testing.DataFrameSuiteBase. Each example comes from an open-source project; the source file, project, and license are noted above each listing.
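Before the project-specific examples, here is a minimal sketch of the basic usage pattern (the suite name and data are illustrative, not taken from any of the projects below): mix DataFrameSuiteBase into a ScalaTest suite, use the spark session it provides, and compare DataFrames with assertDataFrameEquals.

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.FunSuite

// Hypothetical minimal suite: DataFrameSuiteBase supplies `spark`, `sc` and
// `sqlContext`, plus DataFrame equality assertions.
class MinimalDataFrameSuite extends FunSuite with DataFrameSuiteBase {

  test("two identical DataFrames are equal") {
    import spark.implicits._ // for .toDF

    val expected = Seq((1, "a"), (2, "b")).toDF("id", "value")
    val actual = Seq((1, "a"), (2, "b")).toDF("id", "value")

    // Checks both schema and row contents.
    assertDataFrameEquals(expected, actual)
  }
}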
Example 1
Source File: NestedCaseClassesTest.scala    From cleanframes   with Apache License 2.0
package cleanframes

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.sql.functions
import org.scalatest.{FlatSpec, Matchers}


class NestedCaseClassesTest
  extends FlatSpec
    with Matchers
    with DataFrameSuiteBase {

  "Cleaner" should "compile and use a custom transformer for a custom type" in {
    import cleanframes.syntax._ // to use `.clean`
    import spark.implicits._

    // define test data for a dataframe
    val input = Seq(
      // @formatter:off
      ("1",           "1",          "1",           "1",           null),
      (null,          "2",          null,          "2",           "corrupted"),
      ("corrupted",   null,         "corrupted",   null,          "true"),
      ("4",           "corrupted",  "4",           "4",           "false"),
      ("5",           "5",          "5",           "corrupted",   "false"),
      ("6",           "6",          "6",           "6",           "true")
      // @formatter:on
    )
      // give column names that are known to you
      .toDF("col1", "col2", "col3", "col4", "col5")

    // import standard functions for conversions shipped with the library
    import cleanframes.instances.all._

    // !important: you need to give a new structure to allow to access sub elements
    val renamed = input.select(
      functions.struct(
        input.col("col1") as "a_col_1",
        input.col("col2") as "a_col_2"
      ) as "a",
      functions.struct(
        input.col("col3") as "b_col_1",
        input.col("col4") as "b_col_2"
      ) as "b",
      input.col("col5") as "c"
    )

    val result = renamed.clean[AB]
      .as[AB]
      .collect

    result should {
      contain theSameElementsAs Seq(
        // @formatter:off
        AB( A(Some(1), Some(1)),  B(Some(1),  Some(1.0)), Some(false)),
        AB( A(None,    Some(2)),  B(None,     Some(2.0)), Some(false)),
        AB( A(None,    None),     B(None,     None),      Some(true)),
        AB( A(Some(4), None),     B(Some(4),  Some(4.0)), Some(false)),
        AB( A(Some(5), Some(5)),  B(Some(5),  None),      Some(false)),
        AB( A(Some(6), Some(6)),  B(Some(6),  Some(6.0)), Some(true))
        // @formatter:on
      )
    }
  }

}

case class A(a_col_1: Option[Int], a_col_2: Option[Float])

case class B(b_col_1: Option[Float], b_col_2: Option[Double])

case class AB(a: A, b: B, c: Option[Boolean]) 
Example 2
Source File: SparkPFASuiteBase.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.pfa

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.SparkConf
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.scalactic.Equality
import org.scalatest.FunSuite

abstract class SparkPFASuiteBase extends FunSuite with DataFrameSuiteBase with PFATestUtils {

  val sparkTransformer: Transformer
  val input: Array[String]
  val expectedOutput: Array[String]

  val sparkConf =  new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID).
    set("spark.driver.host", "localhost")
  override lazy val spark = SparkSession.builder().config(sparkConf).getOrCreate()
  override val reuseContextIfPossible = true

  // Converts column containing a vector to an array
  def withColumnAsArray(df: DataFrame, colName: String) = {
    val vecToArray = udf { v: Vector => v.toArray }
    df.withColumn(colName, vecToArray(df(colName)))
  }

  def withColumnAsArray(df: DataFrame, first: String, others: String*) = {
    val vecToArray = udf { v: Vector => v.toArray }
    var result = df.withColumn(first, vecToArray(df(first)))
    others.foreach(c => result = result.withColumn(c, vecToArray(df(c))))
    result
  }

  // Converts column containing a vector to a sparse vector represented as a map
  def getColumnAsSparseVectorMap(df: DataFrame, colName: String) = {
    val vecToMap = udf { v: Vector => v.toSparse.indices.map(i => (i.toString, v(i))).toMap }
    df.withColumn(colName, vecToMap(df(colName)))
  }

}

abstract class Result

object ApproxEquality extends ApproxEquality

trait ApproxEquality {

  import org.scalactic.Tolerance._
  import org.scalactic.TripleEquals._

  implicit val seqApproxEq: Equality[Seq[Double]] = new Equality[Seq[Double]] {
    override def areEqual(a: Seq[Double], b: Any): Boolean = {
      b match {
        case d: Seq[Double] =>
          a.zip(d).forall { case (l, r) => l === r +- 0.001 }
        case _ =>
          false
      }
    }
  }

  implicit val vectorApproxEq: Equality[Vector] = new Equality[Vector] {
    override def areEqual(a: Vector, b: Any): Boolean = {
      b match {
        case v: Vector =>
          a.toArray.zip(v.toArray).forall { case (l, r) => l === r +- 0.001 }
        case _ =>
          false
      }
    }
  }
} 
Example 3
Source File: ColumnPruningSuite.scala    From spark-exasol-connector   with Apache License 2.0
package com.exasol.spark

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.funsuite.AnyFunSuite


class ColumnPruningSuite extends AnyFunSuite with BaseDockerSuite with DataFrameSuiteBase {

  test("returns only required columns in query") {
    createDummyTable()

    val df = spark.read
      .format("com.exasol.spark")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $EXA_SCHEMA.$EXA_TABLE")
      .load()
      .select("city")

    assert(df.columns.size === 1)
    assert(df.columns.head === "city")
    val result = df.collect().map(x => x.getString(0)).toSet
    assert(result === Set("Berlin", "Paris", "Lisbon"))
  }

} 
Example 4
Source File: LongReadsTestSuite.scala    From bdg-sequila   with Apache License 2.0
package org.biodatageeks.sequila.tests.coverage

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.{SequilaSession, SparkSession}
import org.biodatageeks.sequila.utils.{Columns, InternalParams, SequilaRegister}
import org.scalatest.{BeforeAndAfter, FunSuite}

class LongReadsTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  val bamPath: String =
    getClass.getResource("/nanopore_guppy_slice.bam").getPath
  val splitSize = 30000
  val tableNameBAM = "reads"

  before {

    System.setSecurityManager(null)
    spark.sql(s"DROP TABLE IF EXISTS $tableNameBAM")
    spark.sql(s"""
         |CREATE TABLE $tableNameBAM
         |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")
         |
      """.stripMargin)

  }
  test("BAM - Nanopore with guppy basecaller") {

    val session: SparkSession = SequilaSession(spark)
    SequilaRegister.register(session)
    session.sparkContext
      .setLogLevel("WARN")
    val bdg = session.sql(s"SELECT * FROM ${tableNameBAM}")
    assert(bdg.count() === 150)
  }

  test("BAM - coverage - Nanopore with guppy basecaller") {
    spark.sqlContext.setConf(InternalParams.InputSplitSize,
                             (splitSize * 10).toString)
    val session2: SparkSession = SequilaSession(spark)
    SequilaRegister.register(session2)
    val query =
      s"""SELECT ${Columns.CONTIG}, ${Columns.START}, ${Columns.COVERAGE}
        FROM bdg_coverage('$tableNameBAM','nanopore_guppy_slice','bases')
        order by ${Columns.CONTIG},${Columns.START},${Columns.END}
        """.stripMargin
    val covMultiPartitionDF = session2.sql(query)

    //covMultiPartitionDF.coalesce(1).write.mode("overwrite").option("delimiter", "\t").csv("/Users/aga/workplace/multiPart.csv")
    assert(covMultiPartitionDF.count() == 45620) // total count check 45620<---> 45842

    assert(covMultiPartitionDF.filter(s"${Columns.COVERAGE}== 0").count == 0)

    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5010515")
        .first()
        .getShort(2) == 1) // value check [first element]
    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5022667")
        .first()
        .getShort(2) == 15) // value check [partition boundary]
    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5036398")
        .first()
        .getShort(2) == 14) // value check [partition boundary]
    assert(
      covMultiPartitionDF
        .where(s"${Columns.CONTIG}='21' and ${Columns.START} == 5056356")
        .first()
        .getShort(2) == 1) // value check [last element]

  }

} 
Example 5
Source File: JoinOrderTestSuite.scala    From bdg-sequila   with Apache License 2.0
package org.biodatageeks.sequila.tests.rangejoins

import java.io.{OutputStreamWriter, PrintWriter}

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{
  IntegerType,
  StringType,
  StructField,
  StructType
}
import org.bdgenomics.utils.instrumentation.{
  Metrics,
  MetricsListener,
  RecordedMetrics
}
import org.biodatageeks.sequila.rangejoins.IntervalTree.IntervalTreeJoinStrategyOptim
import org.scalatest.{BeforeAndAfter, FunSuite}

class JoinOrderTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  val schema = StructType(
    Seq(StructField("chr", StringType),
        StructField("start", IntegerType),
        StructField("end", IntegerType)))
  val metricsListener = new MetricsListener(new RecordedMetrics())
  val writer = new PrintWriter(new OutputStreamWriter(System.out))
  before {
    System.setSecurityManager(null)
    spark.experimental.extraStrategies = new IntervalTreeJoinStrategyOptim(
      spark) :: Nil
    Metrics.initialize(sc)
    val rdd1 = sc
      .textFile(getClass.getResource("/refFlat.txt.bz2").getPath)
      .map(r => r.split('\t'))
      .map(
        r =>
          Row(
            r(2).toString,
            r(4).toInt,
            r(5).toInt
        ))
    val ref = spark.createDataFrame(rdd1, schema)
    ref.createOrReplaceTempView("ref")

    val rdd2 = sc
      .textFile(getClass.getResource("/snp150Flagged.txt.bz2").getPath)
      .map(r => r.split('\t'))
      .map(
        r =>
          Row(
            r(1).toString,
            r(2).toInt,
            r(3).toInt
        ))
    val snp = spark
      .createDataFrame(rdd2, schema)
    snp.createOrReplaceTempView("snp")
  }

  test("Join order - broadcasting snp table") {
    spark.sqlContext.setConf("spark.biodatageeks.rangejoin.useJoinOrder",
                             "true")
    val query =
      s"""
         |SELECT snp.*,ref.* FROM ref JOIN snp
         |ON (ref.chr=snp.chr AND snp.end>=ref.start AND snp.start<=ref.end)
       """.stripMargin

    assert(spark.sql(query).count === 616404L)

  }

  test("Join order - broadcasting ref table") {
    spark.sqlContext.setConf("spark.biodatageeks.rangejoin.useJoinOrder",
                             "true")
    val query =
      s"""
         |SELECT snp.*,ref.* FROM snp JOIN ref
         |ON (ref.chr=snp.chr AND snp.end>=ref.start AND snp.start<=ref.end)
       """.stripMargin
    assert(spark.sql(query).count === 616404L)

  }
  after {
    Metrics.print(writer, Some(metricsListener.metrics.sparkMetrics.stageTimes))
    writer.flush()
    Metrics.stopRecording()
  }
} 
Example 6
Source File: FeatureCountsTestSuite.scala    From bdg-sequila   with Apache License 2.0
package org.biodatageeks.sequila.tests.rangejoins

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import htsjdk.samtools.ValidationStringency
import org.apache.hadoop.io.LongWritable
import org.biodatageeks.sequila.apps.FeatureCounts.Region
import org.biodatageeks.sequila.rangejoins.IntervalTree.IntervalTreeJoinStrategyOptim
import org.biodatageeks.sequila.utils.{Columns, DataQualityFuncs}
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.seqdoop.hadoop_bam.util.SAMHeaderReader
import org.seqdoop.hadoop_bam.{BAMInputFormat, SAMRecordWritable}



class FeatureCountsTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  before {
    System.setSecurityManager(null)
    spark.experimental.extraStrategies = new IntervalTreeJoinStrategyOptim(
      spark) :: Nil
  }

  test("Feature counts for chr1:20138-20294") {
    val query = s"""
        | SELECT count(*),targets.${Columns.CONTIG},targets.${Columns.START},targets.${Columns.END}
        | FROM reads JOIN targets
        |ON (
        |  targets.${Columns.CONTIG}=reads.${Columns.CONTIG}
        |  AND
        |  reads.${Columns.END} >= targets.${Columns.START}
        |  AND
        |  reads.${Columns.START} <= targets.${Columns.END}
        |)
        | GROUP BY targets.${Columns.CONTIG},targets.${Columns.START},targets.${Columns.END}
        | HAVING ${Columns.CONTIG}='1' AND ${Columns.START} = 20138 AND ${Columns.END} = 20294""".stripMargin

    spark.sparkContext.hadoopConfiguration.set(
      SAMHeaderReader.VALIDATION_STRINGENCY_PROPERTY,
      ValidationStringency.SILENT.toString)

    val alignments = spark.sparkContext
      .newAPIHadoopFile[LongWritable, SAMRecordWritable, BAMInputFormat](
        getClass.getResource("/NA12878.slice.bam").getPath)
      .map(_._2.get)
      .map(r => Region(DataQualityFuncs.cleanContig(r.getContig), r.getStart, r.getEnd))

    val reads = spark.sqlContext
      .createDataFrame(alignments)
      .withColumnRenamed("contigName", Columns.CONTIG)
      .withColumnRenamed("start", Columns.START)
      .withColumnRenamed("end", Columns.END)

    reads.createOrReplaceTempView("reads")

    val targets = spark.sqlContext
      .createDataFrame(Array(Region("1", 20138, 20294)))
      .withColumnRenamed("contigName", Columns.CONTIG)
      .withColumnRenamed("start", Columns.START)
      .withColumnRenamed("end", Columns.END)

    targets.createOrReplaceTempView("targets")

    spark.sql(query).explain(false)
    assert(spark.sql(query).first().getLong(0) === 1484L)

  }

} 
Example 7
Source File: PileupTestBase.scala    From bdg-sequila   with Apache License 2.0
package org.biodatageeks.sequila.tests.pileup

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{IntegerType, ShortType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfter, FunSuite}

class PileupTestBase extends FunSuite
  with DataFrameSuiteBase
  with BeforeAndAfter
  with SharedSparkContext{

  val sampleId = "NA12878.multichrom.md"
  val samResPath: String = getClass.getResource("/multichrom/mdbam/samtools.pileup").getPath
  val referencePath: String = getClass.getResource("/reference/Homo_sapiens_assembly18_chr1_chrM.small.fasta").getPath
  val bamPath: String = getClass.getResource(s"/multichrom/mdbam/${sampleId}.bam").getPath
  val cramPath : String = getClass.getResource(s"/multichrom/mdcram/${sampleId}.cram").getPath
  val tableName = "reads_bam"
  val tableNameCRAM = "reads_cram"

  val schema: StructType = StructType(
    List(
      StructField("contig", StringType, nullable = true),
      StructField("position", IntegerType, nullable = true),
      StructField("reference", StringType, nullable = true),
      StructField("coverage", ShortType, nullable = true),
      StructField("pileup", StringType, nullable = true),
      StructField("quality", StringType, nullable = true)
    )
  )
  before {
    System.setProperty("spark.kryo.registrator", "org.biodatageeks.sequila.pileup.serializers.CustomKryoRegistrator")
    spark
      .conf.set("spark.sql.shuffle.partitions",1) //FIXME: In order to get orderBy in Samtools tests working - related to exchange partitions stage
    spark.sql(s"DROP TABLE IF EXISTS $tableName")
    spark.sql(
      s"""
         |CREATE TABLE $tableName
         |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")
         |
      """.stripMargin)

    spark.sql(s"DROP TABLE IF EXISTS $tableNameCRAM")
    spark.sql(
      s"""
         |CREATE TABLE $tableNameCRAM
         |USING org.biodatageeks.sequila.datasources.BAM.CRAMDataSource
         |OPTIONS(path "$cramPath", refPath "$referencePath" )
         |
      """.stripMargin)

    val mapToString = (map: Map[Byte, Short]) => {
      if (map == null)
        "null"
      else
        map.map({
          case (k, v) => k.toChar -> v}).mkString.replace(" -> ", ":")
    }

    val byteToString = ((byte: Byte) => byte.toString)

    spark.udf.register("mapToString", mapToString)
    spark.udf.register("byteToString", byteToString)
  }

} 
Example 8
Source File: VCFDataSourceTestSuite.scala    From bdg-sequila   with Apache License 2.0
package org.biodatageeks.sequila.tests.datasources

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.biodatageeks.sequila.utils.Columns
import org.scalatest.{BeforeAndAfter, FunSuite}

class VCFDataSourceTestSuite
    extends FunSuite
    with DataFrameSuiteBase
    with BeforeAndAfter
    with SharedSparkContext {

  val vcfPath: String = getClass.getResource("/vcf/test.vcf").getPath
  val tableNameVCF = "variants"
  before {
    spark.sql(s"DROP TABLE IF EXISTS $tableNameVCF")
    spark.sql(s"""
         |CREATE TABLE $tableNameVCF
         |USING org.biodatageeks.sequila.datasources.VCF.VCFDataSource
         |OPTIONS(path "$vcfPath")
         |
      """.stripMargin)

  }
  test("VCF - Row count VCFDataSource") {
    val query = s"SELECT * FROM $tableNameVCF"
    spark
      .sql(query)
      .printSchema()

    assert(
      spark
        .sql(query)
        .first()
        .getString(0) === "20")

    assert(spark.sql(query).count() === 7L)

  }

  after {
    spark.sql(s"DROP TABLE IF EXISTS  $tableNameVCF")
  }

} 
Example 9
Source File: BEDBaseTestSuite.scala    From bdg-sequila   with Apache License 2.0
package org.biodatageeks.sequila.tests.base

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.scalatest.{BeforeAndAfter, FunSuite}

class BEDBaseTestSuite
    extends
      FunSuite
    with DataFrameSuiteBase
    with SharedSparkContext with BeforeAndAfter{

  val bedPath: String = getClass.getResource("/bed/test.bed").getPath
  val tableNameBED = "targets"

  val bedSimplePath: String = getClass.getResource("/bed/simple.bed").getPath
  val tableNameSimpleBED = "simple_targets"


  before{
    spark.sql(s"DROP TABLE IF EXISTS $tableNameBED")
    spark.sql(s"""
         |CREATE TABLE $tableNameBED
         |USING org.biodatageeks.sequila.datasources.BED.BEDDataSource
         |OPTIONS(path "$bedPath")
         |
      """.stripMargin)
    spark.sql(s"DROP TABLE IF EXISTS $tableNameSimpleBED")
    spark.sql(s"""
                 |CREATE TABLE $tableNameSimpleBED
                 |USING org.biodatageeks.sequila.datasources.BED.BEDDataSource
                 |OPTIONS(path "$bedSimplePath")
                 |
      """.stripMargin)

  }

  after {
    spark.sql(s"DROP TABLE IF EXISTS $tableNameBED")
    spark.sql(s"DROP TABLE IF EXISTS $tableNameSimpleBED")
  }


} 
Example 10
Source File: TestWithSpark.scala    From ZparkIO   with MIT License
package com.leobenkel.zparkiotest

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.SparkConf
import org.scalatest.Suite

trait TestWithSpark extends DataFrameSuiteBase { self: Suite =>
  override protected val reuseContextIfPossible: Boolean = true
  override protected val enableHiveSupport:      Boolean = false

  
  def enableSparkUI: Boolean = {
    false
  }

  final override def conf: SparkConf = {
    if (enableSparkUI) {
      super.conf
        .set("spark.ui.enabled", "true")
        .set("spark.ui.port", "4050")
    } else {
      super.conf
    }
  }
} 
Example 11
Source File: ExasolRelationSuite.scala    From spark-exasol-connector   with Apache License 2.0
package com.exasol.spark

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

import com.exasol.spark.util.ExasolConnectionManager

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.mockito.Mockito._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar

class ExasolRelationSuite
    extends AnyFunSuite
    with Matchers
    with MockitoSugar
    with DataFrameSuiteBase {

  test("buildScan returns RDD of empty Row-s when requiredColumns is empty (count pushdown)") {
    val query = "SELECT 1"
    val cntQuery = "SELECT COUNT(*) FROM (SELECT 1) A "
    val cnt = 5L

    val manager = mock[ExasolConnectionManager]
    when(manager.withCountQuery(cntQuery)).thenReturn(cnt)

    val relation = new ExasolRelation(spark.sqlContext, query, Option(new StructType), manager)
    val rdd = relation.buildScan()

    assert(rdd.isInstanceOf[RDD[Row]])
    assert(rdd.partitions.size === 4)
    assert(rdd.count === cnt)
    verify(manager, times(1)).withCountQuery(cntQuery)
  }

  test("unhandledFilters should keep non-pushed filters") {
    val schema: StructType = new StructType()
      .add("a", BooleanType)
      .add("b", StringType)
      .add("c", IntegerType)

    val filters = Array[Filter](
      LessThanOrEqual("c", "3"),
      EqualTo("b", "abc"),
      Not(EqualTo("a", false))
    )

    val nullFilters = Array(EqualNullSafe("b", "xyz"))

    val rel = new ExasolRelation(spark.sqlContext, "", Option(schema), null)

    assert(rel.unhandledFilters(filters) === Array.empty[Filter])
    assert(rel.unhandledFilters(filters ++ nullFilters) === nullFilters)
  }

} 
Example 12
Source File: TypesSuite.scala    From spark-exasol-connector   with Apache License 2.0
package com.exasol.spark

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.sql.types._
import org.scalatest.funsuite.AnyFunSuite

class TypesSuite extends AnyFunSuite with BaseDockerSuite with DataFrameSuiteBase {

  test("converts Exasol types to Spark") {
    createAllTypesTable()

    val df = spark.read
      .format("com.exasol.spark")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $EXA_SCHEMA.$EXA_ALL_TYPES_TABLE")
      .load()

    val schemaTest = df.schema

    val schemaExpected = Map(
      "MYID" -> LongType,
      "MYTINYINT" -> ShortType,
      "MYSMALLINT" -> IntegerType,
      "MYBIGINT" -> DecimalType(36, 0),
      "MYDECIMALMAX" -> DecimalType(36, 36),
      "MYDECIMALSYSTEMDEFAULT" -> LongType,
      "MYNUMERIC" -> DecimalType(5, 2),
      "MYDOUBLE" -> DoubleType,
      "MYCHAR" -> StringType,
      "MYNCHAR" -> StringType,
      "MYLONGVARCHAR" -> StringType,
      "MYBOOLEAN" -> BooleanType,
      "MYDATE" -> DateType,
      "MYTIMESTAMP" -> TimestampType,
      "MYGEOMETRY" -> StringType,
      "MYINTERVAL" -> StringType
    )

    val fields = schemaTest.toList
    fields.foreach(field => {
      assert(field.dataType === schemaExpected.get(field.name).get)
    })
  }

} 
Example 13
Source File: PredicatePushdownSuite.scala    From spark-exasol-connector   with Apache License 2.0
package com.exasol.spark

import java.sql.Timestamp

import org.apache.spark.sql.functions.col

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.funsuite.AnyFunSuite


class PredicatePushdownSuite extends AnyFunSuite with BaseDockerSuite with DataFrameSuiteBase {

  test("with where clause build from filters: filter") {
    createDummyTable()

    import spark.implicits._

    val df = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $EXA_SCHEMA.$EXA_TABLE")
      .load()
      .filter($"id" < 3)
      .filter(col("city").like("Ber%"))
      .select("id", "city")

    val result = df.collect().map(x => (x.getLong(0), x.getString(1))).toSet
    assert(result.size === 1)
    assert(result === Set((1, "Berlin")))
  }

  test("with where clause build from filters: createTempView and spark.sql") {
    createDummyTable()

    val df = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $EXA_SCHEMA.$EXA_TABLE")
      .load()

    df.createOrReplaceTempView("myTable")

    val myDF = spark
      .sql("SELECT id, city FROM myTable WHERE id BETWEEN 1 AND 3 AND name < 'Japan'")

    val result = myDF.collect().map(x => (x.getLong(0), x.getString(1))).toSet
    assert(result.size === 2)
    assert(result === Set((1, "Berlin"), (2, "Paris")))
  }

  test("date and timestamp should be read and filtered correctly") {
    import java.sql.Date

    createDummyTable()
    val df = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT date_info, updated_at FROM $EXA_SCHEMA.$EXA_TABLE")
      .load()
    val minTimestamp = Timestamp.valueOf("2017-12-30 00:00:00.0000")
    val testDate = Date.valueOf("2017-12-31")

    val resultDate = df.collect().map(_.getDate(0))
    assert(resultDate.contains(testDate))

    val resultTimestamp = df.collect().map(_.getTimestamp(1)).map(x => x.after(minTimestamp))
    assert(!resultTimestamp.contains(false))

    val filteredByDateDF = df.filter(col("date_info") === testDate)
    assert(filteredByDateDF.count() === 1)

    val filteredByTimestampDF = df.filter(col("updated_at") < minTimestamp)
    assert(filteredByTimestampDF.count() === 0)
  }

  test("count should be performed successfully") {
    createDummyTable()
    val df = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $EXA_SCHEMA.$EXA_TABLE")
      .load()
    val result = df.count()
    assert(result === 3)
  }
} 
Example 14
Source File: ReservedKeywordsSuite.scala    From spark-exasol-connector   with Apache License 2.0
package com.exasol.spark

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.funsuite.AnyFunSuite


class ReservedKeywordsSuite extends AnyFunSuite with BaseDockerSuite with DataFrameSuiteBase {

  val SCHEMA: String = "RESERVED_KEYWORDS"
  val TABLE: String = "TEST_TABLE"

  test("queries a table with reserved keyword") {
    createTable()

    val expected = Set("True", "False", "Checked")

    val df1 = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"""SELECT "CONDITION" FROM $SCHEMA.$TABLE""")
      .load()

    assert(df1.collect().map(x => x(0)).toSet === expected)

    val df2 = spark.read
      .format("exasol")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $SCHEMA.$TABLE")
      .load()
      .select("condition")

    assert(df2.collect().map(x => x(0)).toSet === expected)
  }

  ignore("queries a table with reserved keyword using where clause") {
    createTable()

    val df = spark.read
      .format("com.exasol.spark")
      .option("host", container.host)
      .option("port", s"${container.port}")
      .option("query", s"SELECT * FROM $SCHEMA.$TABLE")
      .load()
      .select(s""""CONDITION"""")
      .where(s""""CONDITION" LIKE '%Check%'""")

    assert(df.collect().map(x => x(0)).toSet === Set("Checked"))
  }

  def createTable(): Unit =
    exaManager.withExecute(
      Seq(
        s"DROP SCHEMA IF EXISTS $SCHEMA CASCADE",
        s"CREATE SCHEMA $SCHEMA",
        s"""|CREATE OR REPLACE TABLE $SCHEMA.$TABLE (
            |   ID INTEGER IDENTITY NOT NULL,
            |   "CONDITION" VARCHAR(100) UTF8
            |)""".stripMargin,
        s"""INSERT INTO $SCHEMA.$TABLE ("CONDITION") VALUES ('True')""",
        s"""INSERT INTO $SCHEMA.$TABLE ("CONDITION") VALUES ('False')""",
        s"""INSERT INTO $SCHEMA.$TABLE ("CONDITION") VALUES ('Checked')""",
        "commit"
      )
    )

} 
Example 15
Source File: StreamingKMeansSuite.scala    From spark-structured-streaming-ml   with Apache License 2.0
package com.highperformancespark.examples.structuredstreaming

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.scalatest.FunSuite
import org.apache.log4j.{Level, Logger}

case class TestRow(features: Vector)

class StreamingKMeansSuite extends FunSuite with DataFrameSuiteBase {

  override def beforeAll(): Unit = {
    super.beforeAll()
    Logger.getLogger("org").setLevel(Level.OFF)
  }

  test("streaming model with one center should converge to true center") {
    import spark.implicits._
    val k = 1
    val dim = 5
    val clusterSpread = 0.1
    val seed = 63
    // TODO: this test is very flaky. The centers do not converge for some
    // (most?) random seeds
    val (batches, trueCenters) =
      StreamingKMeansSuite.generateBatches(100, 80, k, dim, clusterSpread, seed)
    val inputStream = MemoryStream[TestRow]
    val ds = inputStream.toDS()
    val skm = new StreamingKMeans().setK(k).setRandomCenters(dim, 0.01)
    val query = skm.evilTrain(ds.toDF())
    val streamingModels = batches.map { batch =>
      inputStream.addData(batch)
      query.processAllAvailable()
      skm.getModel
    }
    // TODO: use spark's testing suite
    streamingModels.last.centers.zip(trueCenters).foreach {
      case (center, trueCenter) =>
        val centers = center.toArray.mkString(",")
        val trueCenters = trueCenter.toArray.mkString(",")
        println(s"${centers} | ${trueCenters}")
        assert(center.toArray.zip(trueCenter.toArray).forall(
          x => math.abs(x._1 - x._2) < 0.1))
    }
    query.stop()
  }

  def compareBatchAndStreaming(
      batchModel: KMeansModel,
      streamingModel: StreamingKMeansModel,
      validationData: DataFrame): Unit = {
    assert(batchModel.clusterCenters === streamingModel.centers)
    // TODO: implement prediction comparison
  }

}

object StreamingKMeansSuite {

  def generateBatches(
      numPoints: Int,
      numBatches: Int,
      k: Int,
      d: Int,
      r: Double,
      seed: Int,
      initCenters: Array[Vector] = null):
      (IndexedSeq[IndexedSeq[TestRow]], Array[Vector]) = {
    val rand = scala.util.Random
    rand.setSeed(seed)
    val centers = initCenters match {
      case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian())))
      case _ => initCenters
    }
    val data = (0 until numBatches).map { i =>
      (0 until numPoints).map { idx =>
        val center = centers(idx % k)
        val vec = Vectors.dense(
          Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r))
        TestRow(vec)
      }
    }
    (data, centers)
  }
} 
Example 16
Source File: CustomSinkSuite.scala    From spark-structured-streaming-ml   with Apache License 2.0
package com.highperformancespark.examples.structuredstreaming

import com.holdenkarau.spark.testing.DataFrameSuiteBase

import scala.collection.mutable.ListBuffer

import org.scalatest.FunSuite

import org.apache.spark._
import org.apache.spark.sql.{Dataset, DataFrame, Encoder, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream

class CustomSinkSuite extends FunSuite with DataFrameSuiteBase {

  test("really simple test of the custom sink") {
    import spark.implicits._
    val input = MemoryStream[String]
    val doubled = input.toDS().map(x => x + " " + x)
    val formatName = ("com.highperformancespark.examples." +
      "structuredstreaming.CustomSinkCollectorProvider")
    val query = doubled.writeStream
      .queryName("testCustomSinkBasic")
      .format(formatName)
      .start()
    val inputData = List("hi", "holden", "bye", "pandas")
    input.addData(inputData)
    assert(query.isActive === true)
    query.processAllAvailable()
    assert(query.exception === None)
    assert(Pandas.results(0) === inputData.map(x => x + " " + x))
  }
}

object Pandas{
  val results = new ListBuffer[Seq[String]]()
}

class CustomSinkCollectorProvider extends ForeachDatasetSinkProvider {
  override def func(df: DataFrame) {
    val spark = df.sparkSession
    import spark.implicits._
    Pandas.results += df.as[String].rdd.collect()
  }
} 
Example 17
Source File: EncryptedReadSuite.scala    From spark-excel   with Apache License 2.0
package com.crealytics.spark.excel

import org.apache.spark.sql._
import org.apache.spark.sql.types._
import scala.collection.JavaConverters._

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers

object EncryptedReadSuite {
  val simpleSchema = StructType(
    List(
      StructField("A", DoubleType, true),
      StructField("B", DoubleType, true),
      StructField("C", DoubleType, true),
      StructField("D", DoubleType, true)
    )
  )

  val expectedData = List(Row(1.0d, 2.0d, 3.0d, 4.0d)).asJava
}

class EncryptedReadSuite extends AnyFunSpec with DataFrameSuiteBase with Matchers {
  import EncryptedReadSuite._

  lazy val expected = spark.createDataFrame(expectedData, simpleSchema)

  def readFromResources(path: String, password: String, maxRowsInMemory: Option[Int] = None): DataFrame = {
    val url = getClass.getResource(path)
    val reader = spark.read
      .excel(
        dataAddress = s"Sheet1!A1",
        treatEmptyValuesAsNulls = true,
        workbookPassword = password,
        inferSchema = true
      )
    val withMaxRows = maxRowsInMemory.fold(reader)(rows => reader.option("maxRowsInMemory", s"$rows"))
    withMaxRows.load(url.getPath)
  }

  describe("spark-excel") {
    it("should read encrypted xslx file") {
      val df = readFromResources("/spreadsheets/simple_encrypted.xlsx", "fooba")

      assertDataFrameEquals(expected, df)
    }

    it("should read encrypted xlsx file with maxRowsInMem=10") {
      val df = readFromResources("/spreadsheets/simple_encrypted.xlsx", "fooba", maxRowsInMemory = Some(10))

      assertDataFrameEquals(expected, df)
    }

    it("should read encrypted xlsx file with maxRowsInMem=1") {
      val df = readFromResources("/spreadsheets/simple_encrypted.xlsx", "fooba", maxRowsInMemory = Some(1))

      assertDataFrameEquals(expected, df)
    }

    it("should read encrypted xls file") {
      val df = readFromResources("/spreadsheets/simple_encrypted.xls", "fooba")

      assertDataFrameEquals(expected, df)
    }
  }
} 
Example 18
Source File: ProcessTest.scala    From incubator-s2graph   with Apache License 2.0
package org.apache.s2graph.s2jobs.task

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.FunSuite

class ProcessTest extends FunSuite with DataFrameSuiteBase {

  test("SqlProcess execute sql") {
    import spark.implicits._
    
    val inputDF = Seq(
      ("a", "b", "friend"),
      ("a", "c", "friend"),
      ("a", "d", "friend")
    ).toDF("from", "to", "label")

    val inputMap = Map("input" -> inputDF)
    val sql = "SELECT * FROM input WHERE to = 'b'"
    val conf = TaskConf("test", "sql", Seq("input"), Map("sql" -> sql))

    val process = new SqlProcess(conf)

    val rstDF = process.execute(spark, inputMap)
    val tos = rstDF.collect().map{ row => row.getAs[String]("to")}

    assert(tos.size == 1)
    assert(tos.head == "b")
  }
} 
Example 19
Source File: WalLogAggregateProcessTest.scala    From incubator-s2graph   with Apache License 2.0
package org.apache.s2graph.s2jobs.wal.process

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.s2graph.s2jobs.task.TaskConf
import org.apache.s2graph.s2jobs.wal._
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

class WalLogAggregateProcessTest extends FunSuite with Matchers with BeforeAndAfterAll with DataFrameSuiteBase {
  import org.apache.s2graph.s2jobs.wal.TestData._

  test("test entire process") {
    import spark.sqlContext.implicits._

    val edges = spark.createDataset(walLogsLs).toDF()
    val processKey = "agg"
    val inputMap = Map(processKey -> edges)

    val taskConf = new TaskConf(name = "test", `type` = "agg", inputs = Seq(processKey),
      options = Map("maxNumOfEdges" -> "10")
    )

    val job = new WalLogAggregateProcess(taskConf = taskConf)
    val processed = job.execute(spark, inputMap)

    processed.printSchema()
    processed.orderBy("from").as[WalLogAgg].collect().zip(aggExpected).foreach { case (real, expected) =>
      real shouldBe expected
    }
  }

} 
Example 20
Source File: BuildTopFeaturesProcessTest.scala    From incubator-s2graph   with Apache License 2.0
package org.apache.s2graph.s2jobs.wal.process

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.s2graph.s2jobs.task.TaskConf
import org.apache.s2graph.s2jobs.wal.DimValCountRank
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

class BuildTopFeaturesProcessTest extends FunSuite with Matchers with BeforeAndAfterAll with DataFrameSuiteBase {

  import org.apache.s2graph.s2jobs.wal.TestData._

  test("test entire process.") {
    import spark.implicits._
    val df = spark.createDataset(aggExpected).toDF()

    val taskConf = new TaskConf(name = "test", `type` = "test", inputs = Seq("input"),
      options = Map("minUserCount" -> "0")
    )
    val job = new BuildTopFeaturesProcess(taskConf)


    val inputMap = Map("input" -> df)
    val featureDicts = job.execute(spark, inputMap)
      .orderBy("dim", "rank")
      .map(DimValCountRank.fromRow)
      .collect()

    featureDicts shouldBe featureDictExpected

  }
} 
Example 21
Source File: FilterTopFeaturesProcessTest.scala    From incubator-s2graph   with Apache License 2.0
package org.apache.s2graph.s2jobs.wal.process

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.s2graph.s2jobs.task.TaskConf
import org.apache.s2graph.s2jobs.wal.transformer.DefaultTransformer
import org.apache.s2graph.s2jobs.wal.{DimValCountRank, WalLogAgg}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

class FilterTopFeaturesProcessTest extends FunSuite with Matchers with BeforeAndAfterAll with DataFrameSuiteBase {
  import org.apache.s2graph.s2jobs.wal.TestData._

  test("test filterTopKsPerDim.") {
    import spark.implicits._
    val featureDf = spark.createDataset(featureDictExpected).map { x =>
      (x.dimVal.dim, x.dimVal.value, x.count, x.rank)
    }.toDF("dim", "value", "count", "rank")

    val maxRankPerDim = spark.sparkContext.broadcast(Map.empty[String, Int])

    // filter nothing because all feature has rank < 10
    val filtered = FilterTopFeaturesProcess.filterTopKsPerDim(featureDf, maxRankPerDim, 10)

    val real = filtered.orderBy("dim", "rank").map(DimValCountRank.fromRow).collect()
    real.zip(featureDictExpected).foreach { case (real, expected) =>
        real shouldBe expected
    }
    // filter rank >= 2
    val filtered2 = FilterTopFeaturesProcess.filterTopKsPerDim(featureDf, maxRankPerDim, 2)
    val real2 = filtered2.orderBy("dim", "rank").map(DimValCountRank.fromRow).collect()
    real2 shouldBe featureDictExpected.filter(_.rank < 2)
  }


  test("test filterWalLogAgg.") {
    import spark.implicits._
    val walLogAgg = spark.createDataset(aggExpected)
    val featureDf = spark.createDataset(featureDictExpected).map { x =>
      (x.dimVal.dim, x.dimVal.value, x.count, x.rank)
    }.toDF("dim", "value", "count", "rank")
    val maxRankPerDim = spark.sparkContext.broadcast(Map.empty[String, Int])

    val transformers = Seq(DefaultTransformer(TaskConf.Empty))
    // filter nothing. so input, output should be same.
    val featureFiltered = FilterTopFeaturesProcess.filterTopKsPerDim(featureDf, maxRankPerDim, 10)
    val validFeatureHashKeys = FilterTopFeaturesProcess.collectDistinctFeatureHashes(spark, featureFiltered)
    val validFeatureHashKeysBCast = spark.sparkContext.broadcast(validFeatureHashKeys)
    val real = FilterTopFeaturesProcess.filterWalLogAgg(spark, walLogAgg, transformers, validFeatureHashKeysBCast)
      .collect().sortBy(_.from)

    real.zip(aggExpected).foreach { case (real, expected) =>
      real shouldBe expected
    }
  }

  test("test entire process. filter nothing.") {
    import spark.implicits._
    val df = spark.createDataset(aggExpected).toDF()
    val featureDf = spark.createDataset(featureDictExpected).map { x =>
      (x.dimVal.dim, x.dimVal.value, x.count, x.rank)
    }.toDF("dim", "value", "count", "rank")

    val inputKey = "input"
    val featureDictKey = "feature"
    // filter nothing since we did not specified maxRankPerDim and defaultMaxRank.
    val taskConf = new TaskConf(name = "test", `type` = "test",
      inputs = Seq(inputKey, featureDictKey),
      options = Map(
        "featureDict" -> featureDictKey,
        "walLogAgg" -> inputKey
      )
    )
    val inputMap = Map(inputKey -> df, featureDictKey -> featureDf)
    val job = new FilterTopFeaturesProcess(taskConf)
    val filtered = job.execute(spark, inputMap)
      .orderBy("from")
      .as[WalLogAgg]
      .collect()

    filtered.zip(aggExpected).foreach { case (real, expected) =>
      real shouldBe expected
    }

  }
} 
Example 22
Source File: WordCountTest.scala    From robin-sparkles   with Apache License 2.0
package com.highperformancespark.robinsparkles
//import com.highperformancespark.robinsparkles.listener._



import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.FunSuite

class WordCountTest extends FunSuite with DataFrameSuiteBase {
  test("word count with Stop Words Removed"){
    // TODO: Add listener

    val linesRDD = sc.parallelize(Seq(
      "How happy was the panda? You ask.",
      "Panda is the most happy panda in all the#!?ing land!"))

    val stopWords: Set[String] = Set("a", "the", "in", "was", "there", "she", "he")
    val splitTokens: Array[Char] = "#%?!. ".toCharArray

    val wordCounts = WordCount.withStopWordsFiltered(
      linesRDD, splitTokens, stopWords)
    val wordCountsAsMap = wordCounts.collectAsMap()
    assert(!wordCountsAsMap.contains("the"))
    assert(!wordCountsAsMap.contains("?"))
    assert(!wordCountsAsMap.contains("#!?ing"))
    assert(wordCountsAsMap.contains("ing"))
    assert(wordCountsAsMap.get("panda").get.equals(3))
  }
} 
Example 23
Source File: OptionalPrimitivesTest.scala    From cleanframes   with Apache License 2.0
package cleanframes

import org.scalatest.{FlatSpec, Matchers}
import com.holdenkarau.spark.testing.DataFrameSuiteBase


class OptionalPrimitivesTest
  extends FlatSpec
    with Matchers
    with DataFrameSuiteBase {

  "Cleaner" should "transform data to concrete types if possible" in {
    import spark.implicits._ // to use `.toDF` and `.as`
    import cleanframes.syntax._ // to use `.clean`

    // define test data for a dataframe
    val input = Seq(
      // @formatter:off
      ("1",         "1",          "1",          "1",          "1",          "1",          "true"),
      ("corrupted", "2",          "2",          "2",          "2",          "2",          "false"),
      ("3",         "corrupted",  "3",          "3",          "3",          "3",          null),
      ("4",         "4",          "corrupted",  "4",          "4",          "4",          "true"),
      ("5",         "5",          "5",          "corrupted",  "5",          "5",          "false"),
      ("6",         "6",          "6",          "6",          "corrupted",  "6",          null),
      ("7",         "7",          "7",          "7",          "7",          "corrupted",  "true"),
      ("8",         "8",          "8",          "8",          "8",          "8",          "corrupted")
      // @formatter:on
    )
      // important! dataframe's column names must match parameter names of the case class passed to `.clean` method
      .toDF("col1", "col2", "col3", "col4", "col5", "col6", "col7")

    // import standard functions for conversions shipped with the library
    import cleanframes.instances.all._

    val result = input
      // call cleanframes API
      .clean[AnyValsExample]
      // make Dataset
      .as[AnyValsExample]
      .collect

    import cleanframes.{AnyValsExample => Model} // just for readability sake

    result should {
      contain theSameElementsAs Seq(
        // @formatter:off
        Model(Some(1),   Some(1),  Some(1),  Some(1),  Some(1),  Some(1),  Some(true)),
        Model(None,      Some(2),  Some(2),  Some(2),  Some(2),  Some(2),  Some(false)),
        Model(Some(3),   None,     Some(3),  Some(3),  Some(3),  Some(3),  Some(false)),
        Model(Some(4),   Some(4),  None,     Some(4),  Some(4),  Some(4),  Some(true)),
        Model(Some(5),   Some(5),  Some(5),  None,     Some(5),  Some(5),  Some(false)),
        Model(Some(6),   Some(6),  Some(6),  Some(6),  None,     Some(6),  Some(false)),
        Model(Some(7),   Some(7),  Some(7),  Some(7),  Some(7),  None,     Some(true)),
        Model(Some(8),   Some(8),  Some(8),  Some(8),  Some(8),  Some(8),  Some(false))
        // @formatter:on
      )
    }.and(have size 8)

  }

}

case class AnyValsExample(col1: Option[Int],
                          col2: Option[Byte],
                          col3: Option[Short],
                          col4: Option[Long],
                          col5: Option[Float],
                          col6: Option[Double],
                          col7: Option[Boolean]) 
Example 24
Source File: SingleImportInsteadAllTest.scala    From cleanframes   with Apache License 2.0
package cleanframes

import org.scalatest.{FlatSpec, Matchers}
import com.holdenkarau.spark.testing.DataFrameSuiteBase


class SingleImportInsteadAllTest
  extends FlatSpec
    with Matchers
    with DataFrameSuiteBase {

  "Cleaner" should "transform data by using concrete import" in {
    import spark.implicits._ // to use `.toDF` and `.as`
    import cleanframes.syntax._ // to use `.clean`

    // define test data for a dataframe
    val input = Seq(
      ("1"),
      ("corrupted"),
      ("3"),
      ("4"),
      ("5"),
      (null),
      ("null"),
      ("     x   "),
      ("     6 2  "),
      ("6"),
      ("7"),
      ("8")
    )
      // important! dataframe's column names must match parameter names of the case class passed to `.clean` method
      .toDF("col1")

    // import standard functions for conversions shipped with the library
    import cleanframes.instances.int._

    val result = input
      // call cleanframes API
      .clean[SingleIntModel]
      // make Dataset
      .as[SingleIntModel]
      .collect

    result should {
      contain theSameElementsAs Seq(
        SingleIntModel(Some(1)),
        SingleIntModel(None),
        SingleIntModel(Some(3)),
        SingleIntModel(Some(4)),
        SingleIntModel(Some(5)),
        SingleIntModel(None),
        SingleIntModel(None),
        SingleIntModel(None),
        SingleIntModel(None),
        SingleIntModel(Some(6)),
        SingleIntModel(Some(7)),
        SingleIntModel(Some(8))
      )
    }
  }

}

case class SingleIntModel(col1: Option[Int]) 
Example 25
Source File: BigQueryClientSpecs.scala    From spark-bigquery   with Apache License 2.0
package com.samelamin.spark.bigquery

import java.io.File

import com.google.api.services.bigquery.Bigquery
import com.google.api.services.bigquery.model._
import com.google.cloud.hadoop.io.bigquery._
import com.holdenkarau.spark.testing.DataFrameSuiteBase
import com.samelamin.spark.bigquery.converters.{BigQueryAdapter, SchemaConverters}
import org.apache.commons.io.FileUtils
import org.apache.spark.sql._
import org.mockito.Matchers.{any, eq => mockitoEq}
import org.mockito.Mockito._
import org.scalatest.FeatureSpec
import org.scalatest.mock.MockitoSugar


class BigQueryClientSpecs extends FeatureSpec with DataFrameSuiteBase with MockitoSugar {
  val BQProjectId = "google.com:foo-project"

  def setupBigQueryClient(sqlCtx: SQLContext, bigQueryMock: Bigquery): BigQueryClient = {
    val fakeJobReference = new JobReference()
    fakeJobReference.setProjectId(BQProjectId)
    fakeJobReference.setJobId("bigquery-job-1234")
    val dataProjectId = "publicdata"
    // Create the job result.
    val jobStatus = new JobStatus()
    jobStatus.setState("DONE")
    jobStatus.setErrorResult(null)

    val jobHandle = new Job()
    jobHandle.setStatus(jobStatus)
    jobHandle.setJobReference(fakeJobReference)

    // Create table reference.
    val tableRef = new TableReference()
    tableRef.setProjectId(dataProjectId)
    tableRef.setDatasetId("test_dataset")
    tableRef.setTableId("test_table")

    // Mock getting Bigquery jobs
    when(bigQueryMock.jobs().get(any[String], any[String]).execute())
      .thenReturn(jobHandle)
    when(bigQueryMock.jobs().insert(any[String], any[Job]).execute())
      .thenReturn(jobHandle)

    val bigQueryClient = new BigQueryClient(sqlCtx, bigQueryMock)
    bigQueryClient
  }

  scenario("When writing to BQ") {
    val sqlCtx = sqlContext
    import sqlCtx.implicits._
    val gcsPath = "/tmp/testfile2.json"
    FileUtils.deleteQuietly(new File(gcsPath))
    val adaptedDf = BigQueryAdapter(sc.parallelize(List(1, 2, 3)).toDF)
    val bigQueryMock =  mock[Bigquery](RETURNS_DEEP_STUBS)
    val fullyQualifiedOutputTableId = "testProjectID:test_dataset.test"
    val targetTable = BigQueryStrings.parseTableReference(fullyQualifiedOutputTableId)
    val bigQueryClient = setupBigQueryClient(sqlCtx, bigQueryMock)
    val bigQuerySchema = SchemaConverters.SqlToBQSchema(adaptedDf)

    bigQueryClient.load(targetTable,bigQuerySchema,gcsPath)
    verify(bigQueryMock.jobs().insert(mockitoEq(BQProjectId),any[Job]), times(1)).execute()
  }

  scenario("When reading from BQ") {
    val sqlCtx = sqlContext
    val fullyQualifiedOutputTableId = "testProjectID:test_dataset.test"
    val sqlQuery = s"select * from $fullyQualifiedOutputTableId"

    val bqQueryContext = new BigQuerySQLContext(sqlCtx)
    bqQueryContext.setBigQueryProjectId(BQProjectId)
    val bigQueryMock =  mock[Bigquery](RETURNS_DEEP_STUBS)
    val bigQueryClient = setupBigQueryClient(sqlCtx, bigQueryMock)
    bigQueryClient.selectQuery(sqlQuery)
    verify(bigQueryMock.jobs().insert(mockitoEq(BQProjectId),any[Job]), times(1)).execute()
  }

  scenario("When running a DML Queries") {
    val sqlCtx = sqlContext
    val fullyQualifiedOutputTableId = "testProjectID:test_dataset.test"
    val dmlQuery = s"UPDATE $fullyQualifiedOutputTableId SET test_col = new_value WHERE test_col = old_value"
    val bqQueryContext = new BigQuerySQLContext(sqlCtx)
    bqQueryContext.setBigQueryProjectId(BQProjectId)
    val bigQueryMock =  mock[Bigquery](RETURNS_DEEP_STUBS)
    val bigQueryClient = setupBigQueryClient(sqlCtx, bigQueryMock)
    bigQueryClient.runDMLQuery(dmlQuery)
    verify(bigQueryMock.jobs().insert(mockitoEq(BQProjectId),any[Job]), times(1)).execute()
  }
}