org.apache.spark.sql.types.StructType Scala Examples

The following examples show how to use org.apache.spark.sql.types.StructType. Each example lists the source file and project it was taken from.
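
Before the project examples, here is a minimal, self-contained sketch of the pattern most of them build on: declaring a StructType (either from StructField instances or with add) and applying it when creating a DataFrame. The session setup and field names are illustrative assumptions, not taken from any project below.

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object StructTypeSketch {
  def main(args: Array[String]): Unit = {
    // Local session purely for illustration.
    val spark = SparkSession.builder().master("local[*]").appName("StructTypeSketch").getOrCreate()

    // Build a schema from explicit StructFields...
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("name", StringType, nullable = true)
    ))
    // ...or incrementally with add(), as several examples below do.
    val sameSchema = new StructType().add("id", IntegerType, nullable = false).add("name", StringType)
    assert(schema == sameSchema)

    // Apply the schema to an RDD[Row] to get a typed DataFrame.
    val rows = spark.sparkContext.parallelize(Seq(Row(1, "a"), Row(2, null)))
    val df = spark.createDataFrame(rows, schema)
    df.printSchema()

    spark.stop()
  }
}

Both construction styles, StructType(Seq(...)) and new StructType().add(...), recur throughout the examples that follow.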
Example 1
Source File: TextFileFormat.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.text

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.apache.hadoop.util.ReflectionUtils

import org.apache.spark.TaskContext
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration


  def getCompressionExtension(context: TaskAttemptContext): String = {
    // Set the compression extension, similar to code in TextOutputFormat.getDefaultWorkFile
    if (FileOutputFormat.getCompressOutput(context)) {
      val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec])
      ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension
    } else {
      ""
    }
  }
} 
Example 2
Source File: DefaultSource.scala    From spark-snowflake   with Apache License 2.0
package net.snowflake.spark.snowflake

import net.snowflake.spark.snowflake.streaming.SnowflakeSink
import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_SHORT_NAME
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.slf4j.LoggerFactory


  override def createRelation(sqlContext: SQLContext,
                              saveMode: SaveMode,
                              parameters: Map[String, String],
                              data: DataFrame): BaseRelation = {

    val params = Parameters.mergeParameters(parameters)
    // check spark version for push down
    if (params.autoPushdown) {
      SnowflakeConnectorUtils.checkVersionAndEnablePushdown(
        sqlContext.sparkSession
      )
    }
    // pass parameters to pushdown functions
    pushdowns.setGlobalParameter(params)
    val table = params.table.getOrElse {
      throw new IllegalArgumentException(
        "For save operations you must specify a Snowfake table name with the 'dbtable' parameter"
      )
    }

    def tableExists: Boolean = {
      val conn = jdbcWrapper.getConnector(params)
      try {
        jdbcWrapper.tableExists(conn, table.toString)
      } finally {
        conn.close()
      }
    }

    val (doSave, dropExisting) = saveMode match {
      case SaveMode.Append => (true, false)
      case SaveMode.Overwrite => (true, true)
      case SaveMode.ErrorIfExists =>
        if (tableExists) {
          sys.error(
            s"Table $table already exists! (SaveMode is set to ErrorIfExists)"
          )
        } else {
          (true, false)
        }
      case SaveMode.Ignore =>
        if (tableExists) {
          log.info(s"Table $table already exists -- ignoring save request.")
          (false, false)
        } else {
          (true, false)
        }
    }

    if (doSave) {
      val updatedParams = parameters.updated("overwrite", dropExisting.toString)
      new SnowflakeWriter(jdbcWrapper)
        .save(
          sqlContext,
          data,
          saveMode,
          Parameters.mergeParameters(updatedParams)
        )

    }

    createRelation(sqlContext, parameters)
  }

  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink =
    new SnowflakeSink(sqlContext, parameters, partitionColumns, outputMode)
} 
Example 3
Source File: OnErrorSuite.scala    From spark-snowflake   with Apache License 2.0
package net.snowflake.spark.snowflake

import net.snowflake.client.jdbc.SnowflakeSQLException
import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_NAME
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class OnErrorSuite extends IntegrationSuiteBase {
  lazy val table = s"spark_test_table_$randomSuffix"

  lazy val schema = new StructType(
    Array(StructField("var", StringType, nullable = false))
  )

  lazy val df: DataFrame = sparkSession.createDataFrame(
    sc.parallelize(
      Seq(Row("{\"dsadas\nadsa\":12311}"), Row("{\"abc\":334}")) // invalid json key
    ),
    schema
  )

  override def beforeAll(): Unit = {
    super.beforeAll()
    jdbcUpdate(s"create or replace table $table(var variant)")
  }

  override def afterAll(): Unit = {
    jdbcUpdate(s"drop table $table")
    super.afterAll()
  }

  test("continue_on_error off") {

    assertThrows[SnowflakeSQLException] {
      df.write
        .format(SNOWFLAKE_SOURCE_NAME)
        .options(connectorOptionsNoTable)
        .option("dbtable", table)
        .mode(SaveMode.Append)
        .save()
    }
  }

  test("continue_on_error on") {
    df.write
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(connectorOptionsNoTable)
      .option("continue_on_error", "on")
      .option("dbtable", table)
      .mode(SaveMode.Append)
      .save()

    val result = sparkSession.read
      .format(SNOWFLAKE_SOURCE_NAME)
      .options(connectorOptionsNoTable)
      .option("dbtable", table)
      .load()

    assert(result.collect().length == 1)
  }

} 
Example 4
Source File: OrcFileOperator.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.hive.orc

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader}
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.StructType

private[orc] object OrcFileOperator extends Logging {
  
  def getFileReader(basePath: String, config: Option[Configuration] = None): Option[Reader] = {
    def isWithNonEmptySchema(path: Path, reader: Reader): Boolean = {
      reader.getObjectInspector match {
        case oi: StructObjectInspector if oi.getAllStructFieldRefs.size() == 0 =>
          logInfo(
            s"ORC file $path has empty schema, it probably contains no rows. " +
              "Trying to read another ORC file to figure out the schema.")
          false
        case _ => true
      }
    }

    val conf = config.getOrElse(new Configuration)
    val fs = {
      val hdfsPath = new Path(basePath)
      hdfsPath.getFileSystem(conf)
    }

    listOrcFiles(basePath, conf).iterator.map { path =>
      path -> OrcFile.createReader(fs, path)
    }.collectFirst {
      case (path, reader) if isWithNonEmptySchema(path, reader) => reader
    }
  }

  def readSchema(paths: Seq[String], conf: Option[Configuration]): Option[StructType] = {
    // Take the first file where we can open a valid reader if we can find one.  Otherwise just
    // return None to indicate we can't infer the schema.
    paths.flatMap(getFileReader(_, conf)).headOption.map { reader =>
      val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector]
      val schema = readerInspector.getTypeName
      logDebug(s"Reading schema from file $paths, got Hive schema string: $schema")
      CatalystSqlParser.parseDataType(schema).asInstanceOf[StructType]
    }
  }

  def getObjectInspector(
      path: String, conf: Option[Configuration]): Option[StructObjectInspector] = {
    getFileReader(path, conf).map(_.getObjectInspector.asInstanceOf[StructObjectInspector])
  }

  def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = {
    // TODO: Check if the paths coming in are already qualified and simplify.
    val origPath = new Path(pathStr)
    val fs = origPath.getFileSystem(conf)
    val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath)
      .filterNot(_.isDirectory)
      .map(_.getPath)
      .filterNot(_.getName.startsWith("_"))
      .filterNot(_.getName.startsWith("."))
    paths
  }
} 
Example 5
Source File: KustoCsvSerializationUtils.scala    From azure-kusto-spark   with Apache License 2.0
package com.microsoft.kusto.spark.datasink

import java.util.TimeZone

import com.microsoft.kusto.spark.utils.DataTypeMapping
import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.DataTypes._
import org.apache.spark.sql.types.StructType

private[kusto] class KustoCsvSerializationUtils (val schema: StructType, timeZone: String){
  private[kusto] val dateFormat = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX", TimeZone.getTimeZone(timeZone))

  private[kusto] def convertRow(row: InternalRow) = {
    val values = new Array[String](row.numFields)
    for (i <- 0 until row.numFields if !row.isNullAt(i))
    {
      val dataType = schema.fields(i).dataType
      values(i) = dataType match {
          case DateType => DateTimeUtils.toJavaDate(row.getInt(i)).toString
          case TimestampType => dateFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(i)))
          case _ => row.get(i, dataType).toString
        }
    }

    values
  }
}

private[kusto] object KustoCsvMapper {
    import org.apache.spark.sql.types.StructType
    import org.json

    def createCsvMapping(schema: StructType): String = {
      val csvMapping = new json.JSONArray()

      for (i <- 0 until schema.length)
      {
        val field = schema.apply(i)
        val dataType = field.dataType
        val mapping = new json.JSONObject()
        mapping.put("Name", field.name)
        mapping.put("Ordinal", i)
        mapping.put("DataType", DataTypeMapping.sparkTypeToKustoTypeMap.getOrElse(dataType, StringType))

        csvMapping.put(mapping)
      }

      csvMapping.toString
    }
  } 
Example 6
Source File: MNISTBenchmark.scala    From spark-knn   with Apache License 2.0
package com.github.saurfang.spark.ml.knn.examples

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.classification.{KNNClassifier, NaiveKNNClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.tuning.{Benchmarker, ParamGridBuilder}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.log4j

import scala.collection.mutable


object MNISTBenchmark {

  val logger = log4j.Logger.getLogger(getClass)

  def main(args: Array[String]) {
    val ns = if(args.isEmpty) (2500 to 10000 by 2500).toArray else args(0).split(',').map(_.toInt)
    val path = if(args.length >= 2) args(1) else "data/mnist/mnist.bz2"
    val numPartitions = if(args.length >= 3) args(2).toInt else 10
    val models = if(args.length >=4) args(3).split(',') else Array("tree","naive")

    val spark = SparkSession.builder().getOrCreate()
    val sc = spark.sparkContext
    import spark.implicits._

    // read in raw label and features
    val rawDataset = MLUtils.loadLibSVMFile(sc, path)
      .zipWithIndex()
      .filter(_._2 < ns.max)
      .sortBy(_._2, numPartitions = numPartitions)
      .keys
      .toDF()

    // convert "features" from mllib.linalg.Vector to ml.linalg.Vector
    val dataset =  MLUtils.convertVectorColumnsToML(rawDataset)
      .cache()
    dataset.count() //force persist

    val limiter = new Limiter()
    val knn = new KNNClassifier()
      .setTopTreeSize(numPartitions * 10)
      .setFeaturesCol("features")
      .setPredictionCol("prediction")
      .setK(1)
    val naiveKNN = new NaiveKNNClassifier()

    val pipeline = new Pipeline()
      .setStages(Array(limiter, knn))
    val naivePipeline = new Pipeline()
      .setStages(Array(limiter, naiveKNN))

    val paramGrid = new ParamGridBuilder()
      .addGrid(limiter.n, ns)
      .build()

    val bm = new Benchmarker()
      .setEvaluator(new MulticlassClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumTimes(3)

    val metrics = mutable.ArrayBuffer[String]()
    if(models.contains("tree")) {
      val bmModel = bm.setEstimator(pipeline).fit(dataset)
      metrics += s"knn: ${bmModel.avgTrainingRuntimes.toSeq} / ${bmModel.avgEvaluationRuntimes.toSeq}"
    }
    if(models.contains("naive")) {
      val naiveBMModel = bm.setEstimator(naivePipeline).fit(dataset)
      metrics += s"naive: ${naiveBMModel.avgTrainingRuntimes.toSeq} / ${naiveBMModel.avgEvaluationRuntimes.toSeq}"
    }
    logger.info(metrics.mkString("\n"))
  }
}

class Limiter(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("limiter"))

  val n: IntParam = new IntParam(this, "n", "number of rows to limit")

  def setN(value: Int): this.type = set(n, value)

  // hack to maintain number of partitions (otherwise it collapses to 1 which is unfair for naiveKNN)
  override def transform(dataset: Dataset[_]): DataFrame = dataset.limit($(n)).repartition(dataset.rdd.partitions.length).toDF()

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = schema
} 
Example 7
Source File: BatchEvalPythonExec.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}


case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends EvalPythonExec(udfs, output, child) {

  protected override def evaluate(
      funcs: Seq[ChainedPythonFunctions],
      argOffsets: Array[Array[Int]],
      iter: Iterator[InternalRow],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow] = {
    EvaluatePython.registerPicklers()  // register pickler for Row

    val dataTypes = schema.map(_.dataType)
    val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)

    // enable memo iff we serialize the row with schema (schema and class should be memoized)
    val pickle = new Pickler(needConversion)
    // Input iterator to Python: input rows are grouped so we send them in batches to Python.
    // For each row, add it to the queue.
    val inputIterator = iter.map { row =>
      if (needConversion) {
        EvaluatePython.toJava(row, schema)
      } else {
        // fast path for types that do not need conversion in Python
        val fields = new Array[Any](row.numFields)
        var i = 0
        while (i < row.numFields) {
          val dt = dataTypes(i)
          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
          i += 1
        }
        fields
      }
    }.grouped(100).map(x => pickle.dumps(x.toArray))

    // Output iterator for results from Python.
    val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
      .compute(inputIterator, context.partitionId(), context)

    val unpickle = new Unpickler
    val mutableRow = new GenericInternalRow(1)
    val resultType = if (udfs.length == 1) {
      udfs.head.dataType
    } else {
      StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
    }

    val fromJava = EvaluatePython.makeFromJava(resultType)

    outputIterator.flatMap { pickedResult =>
      val unpickledBatch = unpickle.loads(pickedResult)
      unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
    }.map { result =>
      if (udfs.length == 1) {
        // fast path for single UDF
        mutableRow(0) = fromJava(result)
        mutableRow
      } else {
        fromJava(result).asInstanceOf[InternalRow]
      }
    }
  }
} 
Example 8
Source File: MapPartitionsRWrapper.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.r

import org.apache.spark.api.r._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.Row
import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.types.StructType


case class MapPartitionsRWrapper(
    func: Array[Byte],
    packageNames: Array[Byte],
    broadcastVars: Array[Broadcast[Object]],
    inputSchema: StructType,
    outputSchema: StructType) extends (Iterator[Any] => Iterator[Any]) {
  def apply(iter: Iterator[Any]): Iterator[Any] = {
    // Check whether the content of the current DataFrame is serialized R data.
    val isSerializedRData = inputSchema == SERIALIZED_R_DATA_SCHEMA

    val (newIter, deserializer, colNames) =
      if (!isSerializedRData) {
        // Serialize each row into a byte array that can be deserialized in the R worker
        (iter.asInstanceOf[Iterator[Row]].map {row => rowToRBytes(row)},
         SerializationFormats.ROW, inputSchema.fieldNames)
      } else {
        (iter.asInstanceOf[Iterator[Row]].map { row => row(0) }, SerializationFormats.BYTE, null)
      }

    val serializer = if (outputSchema != SERIALIZED_R_DATA_SCHEMA) {
      SerializationFormats.ROW
    } else {
      SerializationFormats.BYTE
    }

    val runner = new RRunner[Array[Byte]](
      func, deserializer, serializer, packageNames, broadcastVars,
      isDataFrame = true, colNames = colNames, mode = RRunnerModes.DATAFRAME_DAPPLY)
    // Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.
    val outputIter = runner.compute(newIter, -1)

    if (serializer == SerializationFormats.ROW) {
      outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
    } else {
      outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
    }
  }
} 
Example 9
Source File: subquery.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}


case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.exchangeReuseEnabled) {
      return plan
    }
    // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls.
    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
    plan transformAllExpressions {
      case sub: ExecSubqueryExpression =>
        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
        val sameResult = sameSchema.find(_.sameResult(sub.plan))
        if (sameResult.isDefined) {
          sub.withNewPlan(sameResult.get)
        } else {
          sameSchema += sub.plan
          sub
        }
    }
  }
} 
Example 10
Source File: resources.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
Example 11
Source File: MetadataLogFileIndex.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import scala.collection.mutable

import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.types.StructType



class MetadataLogFileIndex(
    sparkSession: SparkSession,
    path: Path,
    userSpecifiedSchema: Option[StructType])
  extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) {

  private val metadataDirectory = new Path(path, FileStreamSink.metadataDir)
  logInfo(s"Reading streaming file log from $metadataDirectory")
  private val metadataLog =
    new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, metadataDirectory.toUri.toString)
  private val allFilesFromLog = metadataLog.allFiles().map(_.toFileStatus).filterNot(_.isDirectory)
  private var cachedPartitionSpec: PartitionSpec = _

  override protected val leafFiles: mutable.LinkedHashMap[Path, FileStatus] = {
    new mutable.LinkedHashMap ++= allFilesFromLog.map(f => f.getPath -> f)
  }

  override protected val leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = {
    allFilesFromLog.groupBy(_.getPath.getParent)
  }

  override def rootPaths: Seq[Path] = path :: Nil

  override def refresh(): Unit = { }

  override def partitionSpec(): PartitionSpec = {
    if (cachedPartitionSpec == null) {
      cachedPartitionSpec = inferPartitioning()
    }
    cachedPartitionSpec
  }
} 
Example 12
Source File: StreamingGlobalLimitExec.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.util.concurrent.TimeUnit.NANOSECONDS

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.streaming.state.StateStoreOps
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType}
import org.apache.spark.util.CompletionIterator


case class StreamingGlobalLimitExec(
    streamLimit: Long,
    child: SparkPlan,
    stateInfo: Option[StatefulOperatorStateInfo] = None,
    outputMode: Option[OutputMode] = None)
  extends UnaryExecNode with StateStoreWriter {

  private val keySchema = StructType(Array(StructField("key", NullType)))
  private val valueSchema = StructType(Array(StructField("value", LongType)))

  override protected def doExecute(): RDD[InternalRow] = {
    metrics // force lazy init at driver

    assert(outputMode.isDefined && outputMode.get == InternalOutputModes.Append,
      "StreamingGlobalLimitExec is only valid for streams in Append output mode")

    child.execute().mapPartitionsWithStateStore(
        getStateInfo,
        keySchema,
        valueSchema,
        indexOrdinal = None,
        sqlContext.sessionState,
        Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
      val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null)))
      val numOutputRows = longMetric("numOutputRows")
      val numUpdatedStateRows = longMetric("numUpdatedStateRows")
      val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
      val commitTimeMs = longMetric("commitTimeMs")
      val updatesStartTimeNs = System.nanoTime

      val preBatchRowCount: Long = Option(store.get(key)).map(_.getLong(0)).getOrElse(0L)
      var cumulativeRowCount = preBatchRowCount

      val result = iter.filter { r =>
        val x = cumulativeRowCount < streamLimit
        if (x) {
          cumulativeRowCount += 1
        }
        x
      }

      CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
        if (cumulativeRowCount > preBatchRowCount) {
          numUpdatedStateRows += 1
          numOutputRows += cumulativeRowCount - preBatchRowCount
          store.put(key, getValueRow(cumulativeRowCount))
        }
        allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
        commitTimeMs += timeTakenMs { store.commit() }
        setStoreMetrics(store)
      })
    }
  }

  override def output: Seq[Attribute] = child.output

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override def requiredChildDistribution: Seq[Distribution] = AllTuples :: Nil

  private def getValueRow(value: Long): UnsafeRow = {
    UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value)))
  }
} 
Example 13
Source File: console.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame)
  extends BaseRelation {
  override def schema: StructType = data.schema
}

class ConsoleSinkProvider extends DataSourceV2
  with StreamWriteSupport
  with DataSourceRegister
  with CreatableRelationProvider {

  override def createStreamWriter(
      queryId: String,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceOptions): StreamWriter = {
    new ConsoleWriter(schema, options)
  }

  def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    // Number of rows to display, by default 20 rows
    val numRowsToShow = parameters.get("numRows").map(_.toInt).getOrElse(20)

    // Truncate the displayed data if it is too long, by default it is true
    val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
    data.show(numRowsToShow, isTruncated)

    ConsoleRelation(sqlContext, data)
  }

  def shortName(): String = "console"
} 
Example 14
Source File: StateStoreRDD.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.streaming.state

import java.util.UUID

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.EpochTracker
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration


  override def getPreferredLocations(partition: Partition): Seq[String] = {
    val stateStoreProviderId = StateStoreProviderId(
      StateStoreId(checkpointLocation, operatorId, partition.index),
      queryRunId)
    storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq
  }

  override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
    var store: StateStore = null
    val storeProviderId = StateStoreProviderId(
      StateStoreId(checkpointLocation, operatorId, partition.index),
      queryRunId)

    // If we're in continuous processing mode, we should get the store version for the current
    // epoch rather than the one at planning time.
    val isContinuous = Option(ctxt.getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING))
      .map(_.toBoolean).getOrElse(false)
    val currentVersion = if (isContinuous) {
      val epoch = EpochTracker.getCurrentEpoch
      assert(epoch.isDefined, "Current epoch must be defined for continuous processing streams.")
      epoch.get
    } else {
      storeVersion
    }

    store = StateStore.get(
      storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion,
      storeConf, hadoopConfBroadcast.value.value)
    val inputIter = dataRDD.iterator(partition, ctxt)
    storeUpdateFunction(store, inputIter)
  }
} 
Example 15
Source File: package.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import scala.reflect.ClassTag

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType

package object state {

  implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {

    
    private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
        stateInfo: StatefulOperatorStateInfo,
        keySchema: StructType,
        valueSchema: StructType,
        indexOrdinal: Option[Int],
        sessionState: SessionState,
        storeCoordinator: Option[StateStoreCoordinatorRef])(
        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {

      val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
      val wrappedF = (store: StateStore, iter: Iterator[T]) => {
        // Abort the state store in case of error
        TaskContext.get().addTaskCompletionListener[Unit](_ => {
          if (!store.hasCommitted) store.abort()
        })
        cleanedF(store, iter)
      }

      new StateStoreRDD(
        dataRDD,
        wrappedF,
        stateInfo.checkpointLocation,
        stateInfo.queryRunId,
        stateInfo.operatorId,
        stateInfo.storeVersion,
        keySchema,
        valueSchema,
        indexOrdinal,
        sessionState,
        storeCoordinator)
    }
  }
} 
Example 16
Source File: ConsoleWriter.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.types.StructType


class ConsoleWriter(schema: StructType, options: DataSourceOptions)
    extends StreamWriter with Logging {

  // Number of rows to display, by default 20 rows
  protected val numRowsToShow = options.getInt("numRows", 20)

  // Truncate the displayed data if it is too long, by default it is true
  protected val isTruncated = options.getBoolean("truncate", true)

  assert(SparkSession.getActiveSession.isDefined)
  protected val spark = SparkSession.getActiveSession.get

  def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
    // behavior.
    printRows(messages, schema, s"Batch: $epochId")
  }

  def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

  protected def printRows(
      commitMessages: Array[WriterCommitMessage],
      schema: StructType,
      printMessage: String): Unit = {
    val rows = commitMessages.collect {
      case PackedRowCommitMessage(rs) => rs
    }.flatten

    // scalastyle:off println
    println("-------------------------------------------")
    println(printMessage)
    println("-------------------------------------------")
    // scalastyle:on println
    Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
      .show(numRowsToShow, isTruncated)
  }

  override def toString(): String = {
    s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
  }
} 
Example 17
Source File: ExplainSuite.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType

class ExplainSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  
  private def checkKeywordsExistsInExplain(df: DataFrame, keywords: String*): Unit = {
    val output = new java.io.ByteArrayOutputStream()
    Console.withOut(output) {
      df.explain(extended = false)
    }
    for (key <- keywords) {
      assert(output.toString.contains(key))
    }
  }

  test("SPARK-23034 show rdd names in RDD scan nodes (Dataset)") {
    val rddWithName = spark.sparkContext.parallelize(Row(1, "abc") :: Nil).setName("testRdd")
    val df = spark.createDataFrame(rddWithName, StructType.fromDDL("c0 int, c1 string"))
    checkKeywordsExistsInExplain(df, keywords = "Scan ExistingRDD testRdd")
  }

  test("SPARK-23034 show rdd names in RDD scan nodes (DataFrame)") {
    val rddWithName = spark.sparkContext.parallelize(ExplainSingleData(1) :: Nil).setName("testRdd")
    val df = spark.createDataFrame(rddWithName)
    checkKeywordsExistsInExplain(df, keywords = "Scan testRdd")
  }

  test("SPARK-24850 InMemoryRelation string representation does not include cached plan") {
    val df = Seq(1).toDF("a").cache()
    checkKeywordsExistsInExplain(df,
      keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)")
  }
}

case class ExplainSingleData(id: Int) 
Example 18
Source File: GroupedIteratorSuite.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 19
Source File: MemorySinkV2Suite.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.sources._
import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
import org.apache.spark.sql.types.StructType

class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
  test("data writer") {
    val partition = 1234
    val writer = new MemoryDataWriter(
      partition, OutputMode.Append(), new StructType().add("i", "int"))
    writer.write(InternalRow(1))
    writer.write(InternalRow(2))
    writer.write(InternalRow(44))
    val msg = writer.commit()
    assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44))
    assert(msg.partition == partition)

    // Buffer should be cleared, so repeated commits should give empty.
    assert(writer.commit().data.isEmpty)
  }

  test("streaming writer") {
    val sink = new MemorySinkV2
    val writeSupport = new MemoryStreamWriter(
      sink, OutputMode.Append(), new StructType().add("i", "int"))
    writeSupport.commit(0,
      Array(
        MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
        MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
        MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))
      ))
    assert(sink.latestBatchId.contains(0))
    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
    writeSupport.commit(19,
      Array(
        MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
        MemoryWriterCommitMessage(0, Seq(Row(33)))
      ))
    assert(sink.latestBatchId.contains(19))
    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))

    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
  }
} 
Example 20
Source File: BlockingSource.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    BlockingSource.latch.await()
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
        Seq[Int]().toDS().toDF()
      }
      override def stop() {}
    }
  }

  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}
    }
  }
}

object BlockingSource {
  var latch: CountDownLatch = null
} 
Example 21
Source File: MockSourceProvider.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.streaming.util

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


class MockSourceProvider extends StreamSourceProvider {
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", MockSourceProvider.fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    MockSourceProvider.sourceProviderFunction()
  }
}

object MockSourceProvider {
  // Function to generate sources. May provide multiple sources if the user implements such a
  // function.
  private var sourceProviderFunction: () => Source = _

  final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = {
    var i = 0
    val sources = source +: otherSources
    sourceProviderFunction = () => {
      val source = sources(i % sources.length)
      i += 1
      source
    }
    try {
      f
    } finally {
      sourceProviderFunction = null
    }
  }
} 
Example 22
Source File: ShowTablesUsingCommand.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.sources.DatasourceCatalog
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types.{StringType, StructField, StructType}


private[sql]
case class ShowTablesUsingCommand(provider: String, options: Map[String, String])
  extends LogicalPlan
  with RunnableCommand {

  override def output: Seq[Attribute] = StructType(
    StructField("TABLE_NAME", StringType, nullable = false) ::
    StructField("IS_TEMPORARY", StringType, nullable = false) ::
    StructField("KIND", StringType, nullable = false) ::
    Nil
  ).toAttributes

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val dataSource: Any = DatasourceResolver.resolverFor(sqlContext).newInstanceOf(provider)

    dataSource match {
      case describableRelation: DatasourceCatalog =>
        describableRelation
          .getRelations(sqlContext, new CaseInsensitiveMap(options))
          .map(relationInfo => Row(
            relationInfo.name,
            relationInfo.isTemporary.toString.toUpperCase,
            relationInfo.kind.toUpperCase))
      case _ =>
        throw new RuntimeException(s"The provided data source $provider does not support " +
        "showing its relations.")
    }
  }
} 
Example 23
Source File: DeepDescribeCommand.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources.describable.Describable
import org.apache.spark.sql.sources.describable.FieldLike.StructFieldLike
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}


private[sql] case class DeepDescribeCommand(
    relation: Describable)
  extends RunnableCommand {

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val description = relation.describe()
    Seq(description match {
      case r: Row => r
      case default => Row(default)
    })
  }

  override def output: Seq[Attribute] = {
    relation.describeOutput match {
      case StructType(fields) =>
        fields.map(StructFieldLike.toAttribute)
      case other =>
        AttributeReference("value", other)() :: Nil
    }
  }
} 
Example 24
Source File: DescribeTableUsingCommand.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.TableIdentifierUtils._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources.{DatasourceCatalog, RelationInfo}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}


private[sql]
case class DescribeTableUsingCommand(
    name: TableIdentifier,
    provider: String,
    options: Map[String, String])
  extends LogicalPlan
  with RunnableCommand {

  override def output: Seq[Attribute] = StructType(
    StructField("TABLE_NAME", StringType, nullable = false) ::
    StructField("DDL_STMT", StringType, nullable = false) ::
    Nil
  ).toAttributes

  override def run(sqlContext: SQLContext): Seq[Row] = {
    // Convert the table name according to the case-sensitivity settings
    val tableId = name.toSeq
    val resolver = DatasourceResolver.resolverFor(sqlContext)
    val catalog = resolver.newInstanceOfTyped[DatasourceCatalog](provider)

    Seq(catalog
      .getRelation(sqlContext, tableId, new CaseInsensitiveMap(options)) match {
        case None => Row("", "")
        case Some(RelationInfo(relName, _, _, ddl, _)) => Row(
          relName, ddl.getOrElse(""))
    })
  }
} 
Example 25
Source File: RawDDLCommand.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.sources.RawDDLObjectType.RawDDLObjectType
import org.apache.spark.sql.sources.RawDDLStatementType.RawDDLStatementType
import org.apache.spark.sql.sources.RawSqlSourceProvider
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.execution.RunnableCommand


private[sql] case class RawDDLCommand(
    identifier: String,
    objectType: RawDDLObjectType,
    statementType: RawDDLStatementType,
    sparkSchema: Option[StructType],
    ddlStatement: String,
    provider: String,
    options: Map[String, String])
  extends RunnableCommand {

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val dataSource: Any = ResolvedDataSource.lookupDataSource(provider).newInstance()

    dataSource match {
      case rsp: RawSqlSourceProvider =>
        rsp.executeDDL(identifier, objectType, statementType, sparkSchema, ddlStatement, options)
      case _ => throw new RuntimeException("The provided datasource does not support " +
        "executing raw DDL statements.")
    }
    Seq.empty[Row]
  }
} 
Example 26
Source File: CreateTablePartitionedByUsing.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.types.StructType


case class CreateTablePartitionedByUsing(tableIdent: TableIdentifier,
                                         userSpecifiedSchema:
                                         Option[StructType],
                                         provider: String,
                                         partitioningFunc: String,
                                         partitioningColumns: Seq[String],
                                         temporary: Boolean,
                                         options: Map[String, String],
                                         allowExisting: Boolean,
                                         managedIfNoPath: Boolean)
  extends LogicalPlan with Command {

  override def output: Seq[Attribute] = Seq.empty
  override def children: Seq[LogicalPlan] = Seq.empty
} 
Example 27
Source File: PartitionedRelationProvider.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.sources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.StructType


trait PartitionedRelationProvider
  extends SchemaRelationProvider
  with TemporaryAndPersistentNature {

  def createRelation(sqlContext: SQLContext,
                     tableName: Seq[String],
                     parameters: Map[String, String],
                     partitioningFunction: Option[String],
                     partitioningColumns: Option[Seq[String]],
                     isTemporary: Boolean,
                     allowExisting: Boolean): BaseRelation

  def createRelation(sqlContext: SQLContext,
                     tableName: Seq[String],
                     parameters: Map[String, String],
                     schema: StructType,
                     partitioningFunction: Option[String],
                     partitioningColumns: Option[Seq[String]],
                     isTemporary: Boolean,
                     allowExisting: Boolean): BaseRelation

} 
Example 28
Source File: RawSqlSourceProvider.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.sources

import java.util.concurrent.atomic.AtomicReference

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.execution.{PhysicalRDD, RDDConversions, SparkPlan}
import org.apache.spark.sql.sources.RawDDLObjectType.RawDDLObjectType
import org.apache.spark.sql.sources.RawDDLStatementType.RawDDLStatementType
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}

case object RawDDLObjectType {

  sealed trait RawDDLObjectType {
    val name: String
    override def toString: String = name
  }

  sealed abstract class BaseRawDDLObjectType(val name: String) extends RawDDLObjectType
  sealed trait RawData

  case object PartitionFunction extends BaseRawDDLObjectType("partition function")
  case object PartitionScheme   extends BaseRawDDLObjectType("partition scheme")
  case object Collection        extends BaseRawDDLObjectType("collection") with RawData
  case object Series            extends BaseRawDDLObjectType("table") with RawData
  case object Graph             extends BaseRawDDLObjectType("graph") with RawData
}

case object RawDDLStatementType {

  sealed trait RawDDLStatementType

  case object Create extends RawDDLStatementType
  case object Drop   extends RawDDLStatementType
  case object Append extends RawDDLStatementType
  case object Load   extends RawDDLStatementType
}


  protected def calculateSchema(): StructType
} 
Example 29
Source File: dependenciesSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableDependencyCalculator
import org.apache.spark.sql.sources.{RelationKind, Table}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}

object DependenciesSystemTableProvider extends SystemTableProvider with LocalSpark {
  
  override def execute(): Seq[Row] = {
    val tables = getTables(sqlContext.catalog)
    val dependentsMap = buildDependentsMap(tables)

    def kindOf(tableIdentifier: TableIdentifier): String =
      tables
        .get(tableIdentifier)
        .map(plan => RelationKind.kindOf(plan).getOrElse(Table).name)
        .getOrElse(DependenciesSystemTable.UnknownType)
        .toUpperCase

    dependentsMap.flatMap {
      case (tableIdent, dependents) =>
        val curKind = kindOf(tableIdent)
        dependents.map { dependent =>
          val dependentKind = kindOf(dependent)
          Row(
            tableIdent.database.orNull,
            tableIdent.table,
            curKind,
            dependent.database.orNull,
            dependent.table,
            dependentKind,
            ReferenceDependency.id)
        }
    }.toSeq
  }

  override val schema: StructType = DependenciesSystemTable.schema
}

object DependenciesSystemTable extends SchemaEnumeration {
  val baseSchemaName = Field("BASE_SCHEMA_NAME", StringType, nullable = true)
  val baseObjectName = Field("BASE_OBJECT_NAME", StringType, nullable = false)
  val baseObjectType = Field("BASE_OBJECT_TYPE", StringType, nullable = false)
  val dependentSchemaName = Field("DEPENDENT_SCHEMA_NAME", StringType, nullable = true)
  val dependentObjectName = Field("DEPENDENT_OBJECT_NAME", StringType, nullable = false)
  val dependentObjectType = Field("DEPENDENT_OBJECT_TYPE", StringType, nullable = false)
  val dependencyType = Field("DEPENDENCY_TYPE", IntegerType, nullable = false)

  private[DependenciesSystemTable] val UnknownType = "UNKNOWN"
} 
Example 30
Source File: partitionFunctionSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.sql.execution.tablefunctions.OutputFormatter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.GenericUtil._


  private def typeNameOf(f: PartitionFunction): String = f match {
    case _: RangePartitionFunction => "RANGE"
    case _: BlockPartitionFunction => "BLOCK"
    case _: HashPartitionFunction => "HASH"
  }
}

object PartitionFunctionSystemTable extends SchemaEnumeration {
  val id = Field("ID", StringType, nullable = false)
  val functionType = Field("TYPE", StringType, nullable = false)
  val columnName = Field("COLUMN_NAME", StringType, nullable = false)
  val columnType = Field("COLUMN_TYPE", StringType, nullable = false)
  val boundaries = Field("BOUNDARIES", StringType, nullable = true)
  val block = Field("BLOCK_SIZE", IntegerType, nullable = true)
  val partitions = Field("PARTITIONS", IntegerType, nullable = true)
  val minP = Field("MIN_PARTITIONS", IntegerType, nullable = true)
  val maxP = Field("MAX_PARTITIONS", IntegerType, nullable = true)
} 
Example 31
Source File: sessionSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{Row, SQLConf, SQLContext}


  private def allSettingsOf(conf: SQLConf): Map[String, String] = {
    val setConfs = conf.getAllConfs
    val defaultConfs = conf.getAllDefinedConfs.collect {
      case (key, default, _) if !setConfs.contains(key) => key -> default
    }
    setConfs ++ defaultConfs
  }

  override def schema: StructType = SessionSystemTable.schema
}

object SessionSystemTable extends SchemaEnumeration {
  val section = Field("SECTION", StringType, nullable = false)
  val key = Field("KEY", StringType, nullable = false)
  val value = Field("VALUE", StringType, nullable = true)
} 
Example 32
Source File: tablesSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.sources._
import org.apache.spark.sql.sources.commands.WithOrigin
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.util.CollectionUtils.CaseInsensitiveMap
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.catalyst.CaseSensitivityUtils._

object TablesSystemTableProvider extends SystemTableProvider with LocalSpark with ProviderBound {
  
  override def buildScan(requiredColumns: Array[String],
                         filters: Array[Filter]): RDD[Row] =
    DatasourceResolver
      .resolverFor(sqlContext)
      .newInstanceOfTyped[DatasourceCatalog](provider) match {
      case catalog: DatasourceCatalog with DatasourceCatalogPushDown =>
        catalog.getRelations(sqlContext, options, requiredColumns, filters.toSeq.merge)
      case catalog: DatasourceCatalog =>
        val values =
          catalog
            .getRelations(sqlContext, new CaseInsensitiveMap(options))
            .map(relationInfo => Row(
              relationInfo.name,
              relationInfo.isTemporary.toString.toUpperCase,
              relationInfo.kind.toUpperCase,
              relationInfo.provider))
        val rows = schema.buildPrunedFilteredScan(requiredColumns, filters)(values)
        sparkContext.parallelize(rows)
    }
}

sealed trait TablesSystemTable extends SystemTable {
  override def schema: StructType = TablesSystemTable.schema
}

object TablesSystemTable extends SchemaEnumeration {
  val tableName = Field("TABLE_NAME", StringType, nullable = false)
  val isTemporary = Field("IS_TEMPORARY", StringType, nullable = false)
  val kind = Field("KIND", StringType, nullable = false)
  val provider = Field("PROVIDER", StringType, nullable = true)
} 
Example 33
Source File: metadataSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.tablefunctions.OutputFormatter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.catalyst.CaseSensitivityUtils._


  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] =
    DatasourceResolver
      .resolverFor(sqlContext)
      .newInstanceOfTyped[MetadataCatalog](provider) match {
      case catalog: MetadataCatalog with MetadataCatalogPushDown =>
        catalog.getTableMetadata(sqlContext, options, requiredColumns, filters.toSeq.merge)
      case catalog =>
        val rows = catalog.getTableMetadata(sqlContext, options).flatMap { tableMetadata =>
          val formatter = new OutputFormatter(tableMetadata.tableName, tableMetadata.metadata)
          formatter.format().map(Row.fromSeq)
        }
        sparkContext.parallelize(schema.buildPrunedFilteredScan(requiredColumns, filters)(rows))
    }

  override def schema: StructType = MetadataSystemTable.schema
}

object MetadataSystemTable extends SchemaEnumeration {
  val tableName = Field("TABLE_NAME", StringType, nullable = false)
  val metadataKey = Field("METADATA_KEY", StringType, nullable = true)
  val metadataValue = Field("METADATA_VALUE", StringType, nullable = true)
} 
Example 34
Source File: ScanAndFilterImplicits.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.{And, Filter, FilterUtils}
import org.apache.spark.sql.types.StructType
import FilterUtils._


        values.foldLeft(Seq.empty[Row]) {
          case (acc, value) =>
            val scanned = scanFunction(value)
            if (validation(scanned)) {
              acc :+ scanned
            } else acc
        }
    }
  }

}

object ScanAndFilterImplicits extends ScanAndFilterImplicits 
Example 35
Source File: relationMappingSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.sql.SqlLikeRelation
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{Row, SQLContext}

object RelationMappingSystemTableProvider extends SystemTableProvider with LocalSpark {

  
  override def execute(): Seq[Row] = {
    sqlContext.tableNames().map { tableName =>
      val plan = sqlContext.catalog.lookupRelation(TableIdentifier(tableName))
      val sqlName = plan.collectFirst {
        case s: SqlLikeRelation =>
          s.relationName
        case LogicalRelation(s: SqlLikeRelation, _) =>
          s.relationName
      }
      Row(tableName, sqlName)
    }
  }
}

object RelationMappingSystemTable extends SchemaEnumeration {
  val sparkName = Field("RELATION_NAME", StringType, nullable = false)
  val providerName = Field("SQL_NAME", StringType, nullable = true)
} 
Example 36
Source File: CaseSensitivityUtils.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog}
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.util.CollectionUtils._

import scala.util.{Failure, Success, Try}


  case class DuplicateFieldsException(
      originalSchema: StructType,
      schema: StructType,
      duplicateFields: Set[String])
    extends RuntimeException(
      s"""Given schema contains duplicate fields after applying case sensitivity rules:
         |${duplicateFields.mkString(", ")}
         |Given schema:
         |$originalSchema
         |After applying case sensitivity rules:
         |$schema
       """.stripMargin)
} 
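To see the condition DuplicateFieldsException describes, here is a minimal sketch using only standard Spark types: two field names that differ only in case collapse into duplicates once case-sensitivity rules lower-case them.

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object DuplicateFieldsSketch {
  // Two names differing only in case...
  val original = StructType(Seq(
    StructField("Foo", StringType),
    StructField("foo", IntegerType)))

  // ...collapse into the same name once lower-cased.
  val lowerCased = StructType(original.fields.map(f => f.copy(name = f.name.toLowerCase)))

  // duplicateFields == Set("foo"), which is what DuplicateFieldsException reports.
  val duplicateFields: Set[String] = lowerCased.fieldNames
    .groupBy(identity)
    .collect { case (name, occurrences) if occurrences.length > 1 => name }
    .toSet
}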
Example 37
Source File: KPISmokeSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import com.sap.commons.TimingTestUtils
import com.sap.spark.util.TestUtils._
import org.apache.spark.util.DummyRelationUtils._
import org.apache.spark.sql.types.StructType
import org.scalatest.FunSuite



  test("Create/drop does not change with number of runs") {
    val sampleSize = 40
    val warmup = 10
    def action(): Unit = {
      createTestTable()
      dropTestTable()
    }
    val (res, corr, samples) =
      TimingTestUtils.executionTimeNotCorrelatedWithRuns(
        acceptedCorrelation, warmup, sampleSize)(action)
    assert(res.booleanValue(), s"Correlation check failed. Correlation is $corr. " +
      s"Accepted correlation is $acceptedCorrelation. Determined samples: $samples")
  }
} 
Example 38
Source File: HiveEmulationSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.hive.SapHiveContext
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{AnalysisException, GlobalSapSQLContext, Row, SapSQLConf}
import org.apache.spark.util.DummyRelationUtils._
import org.apache.spark.util.SqlContextConfigurationUtils
import org.scalatest.FunSuite

class HiveEmulationSuite
  extends FunSuite
  with GlobalSapSQLContext
  with SqlContextConfigurationUtils {

  private def createTable(name: String, schema: StructType): Unit =
    sqlc.createDataFrame(sc.parallelize(Seq.empty[Row]), schema).registerTempTable(name)

  private def withHiveEmulation[A](op: => A): A =
    withConf(SapSQLConf.HIVE_EMULATION.key, "true")(op)

  test("Show schemas shows a default schema when hive emulation is on") {
    withHiveEmulation {
      val values = sqlc.sql("SHOW SCHEMAS").collect().toSet

      assertResult(Set(Row("default")))(values)
    }
  }

  test("Show schemas throws if hive emulation is off") {
    intercept[RuntimeException](sqlc.sql("SHOW SCHEMAS"))
  }

  test("Desc an existing table") {
    withHiveEmulation {
      createTable("foo", StructType('a.int :: 'b.int :: Nil))
      val values = sqlc.sql("DESC foo").collect().toSet

      assertResult(
        Set(
          Row("a", "int", null),
          Row("b", "int", null)))(values)
    }
  }

  test("Desc a non-existent table throws") {
    withHiveEmulation {
      intercept[NoSuchTableException] {
        sqlc.sql("DESC bar").collect()
      }
    }
  }

  test("Describe an existing table") {
    withHiveEmulation {
      createTable("foo", StructType('a.int :: 'b.int :: Nil))
      val values = sqlc.sql("DESCRIBE FORMATTED foo").collect().toList
      assertResult(
        List(
          Row(s"# col_name${" " * 12}\tdata_type${" " * 11}\tcomment${" " * 13}\t"),
          Row(""),
          Row(s"a${" " * 19}\tint${" " * 17}\tnull${" " * 16}\t"),
          Row(s"b${" " * 19}\tint${" " * 17}\tnull${" " * 16}\t")))(values)
    }
  }

  test("Retrieval of a database prefixed table") {
    val hc = new SapHiveContext(sc)
    hc.setConf(SapSQLConf.HIVE_EMULATION, true)
    val expected = Set(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1))
    val rdd = hc.sparkContext.parallelize(expected.toSeq)
    hc.createDataFrame(rdd, StructType('a.int :: 'b.int :: Nil)).registerTempTable("foo")

    val results = hc.sql("SELECT * FROM default.foo").collect().toSet
    assertResult(expected)(results)

    hc.setConf(SapSQLConf.HIVE_EMULATION, false)
    intercept[AnalysisException] {
      hc.sql("SELECT * FROM default.foo")
    }
  }

  test("USE statements should not do anything when in hive emulation mode") {
    withConf(SapSQLConf.HIVE_EMULATION.key, "true") {
      sqlc.sql("USE foo bar")
    }
  }

  test("Any other use command should throw an exception") {
    intercept[RuntimeException] {
      sqlc.sql("USE foo bar")
    }
  }
} 
Example 39
Source File: CollapseExpandSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.analysis.CollapseExpandSuite.SqlLikeCatalystSourceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.sources.sql.SqlLikeRelation
import org.apache.spark.sql.sources.{BaseRelation, CatalystSource, Table}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.util.PlanComparisonUtils._
import org.apache.spark.sql.{GlobalSapSQLContext, Row}
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar


class CollapseExpandSuite extends FunSuite with MockitoSugar with GlobalSapSQLContext {
  case object Leaf extends LeafNode {
    override def output: Seq[Attribute] = Seq.empty
  }

  test("Expansion with a single sequence of projections is correctly collapsed") {
    val expand =
      Expand(
        Seq(Seq('a.string, Literal(1))),
        Seq('a.string, 'gid.int),
        Leaf)

    val collapsed = CollapseExpand(expand)
    assertResult(normalizeExprIds(Project(Seq('a.string, Literal(1) as "gid"), Leaf)))(
      normalizeExprIds(collapsed))
  }

  test("Expansion with multiple projections is correctly collapsed") {
    val expand =
      Expand(
        Seq(
          Seq('a.string, Literal(1)),
          Seq('b.string, Literal(1))),
        Seq('a.string, 'gid1.int, 'b.string, 'gid2.int),
        Leaf)

    val collapsed = CollapseExpand(expand)
    assertResult(
      normalizeExprIds(
        Project(Seq(
            'a.string,
            Literal(1) as "gid1",
            'b.string,
            Literal(1) as "gid2"),
          Leaf)))(normalizeExprIds(collapsed))
  }

  test("Expand pushdown integration") {
    val relation = mock[SqlLikeCatalystSourceRelation]
    when(relation.supportsLogicalPlan(any[Expand]))
      .thenReturn(true)
    when(relation.isMultiplePartitionExecution(any[Seq[CatalystSource]]))
      .thenReturn(true)
    when(relation.schema)
      .thenReturn(StructType(StructField("foo", StringType) :: Nil))
    when(relation.relationName)
      .thenReturn("t")
    when(relation.logicalPlanToRDD(any[LogicalPlan]))
      .thenReturn(sc.parallelize(Seq(Row("a", 1), Row("b", 1), Row("a", 1))))

    sqlc.baseRelationToDataFrame(relation).registerTempTable("t")

    val dataFrame = sqlc.sql("SELECT COUNT(DISTINCT foo) FROM t")
    val Seq(Row(ct)) = dataFrame.collect().toSeq

    assertResult(2)(ct)
  }
}

object CollapseExpandSuite {
  abstract class SqlLikeCatalystSourceRelation
    extends BaseRelation
    with Table
    with SqlLikeRelation
    with CatalystSource
} 
Example 40
Source File: ResolveCountDistinctStarSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FunSuite
import org.scalatest.Inside._
import org.scalatest.mock.MockitoSugar
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}

import scala.collection.mutable.ArrayBuffer

class ResolveCountDistinctStarSuite extends FunSuite with MockitoSugar {
  val persons = new LogicalRelation(new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]
    override def schema: StructType = StructType(Seq(
      StructField("age", IntegerType),
      StructField("name", StringType)
    ))
  })

  test("Count distinct star is resolved correctly") {
    val projection = persons.select(UnresolvedAlias(
      AggregateExpression(Count(UnresolvedStar(None) :: Nil), Complete, true)))
    val stillNotCompletelyResolvedAggregate = SimpleAnalyzer.execute(projection)
    val resolvedAggregate = ResolveCountDistinctStar(SimpleAnalyzer)
                              .apply(stillNotCompletelyResolvedAggregate)
    inside(resolvedAggregate) {
      case Aggregate(Nil,
      ArrayBuffer(Alias(AggregateExpression(Count(expressions), Complete, true), _)), _) =>
        assert(expressions.collect {
          case a:AttributeReference => a.name
        }.toSet == Set("name", "age"))
    }
    assert(resolvedAggregate.resolved)
  }
} 
Example 41
Source File: HiveSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import com.sap.spark.{GlobalSparkContext, WithSapHiveContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.SapHiveContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.scalatest.FunSuite


class HiveSuite
  extends FunSuite
  with GlobalSparkContext
  with WithSapHiveContext {

  val schema = StructType(
    StructField("foo", StringType) ::
    StructField("bar", StringType) :: Nil)

  test("NewSession returns a new SapHiveContext") {
    val hiveContext = sqlc.asInstanceOf[SapHiveContext]
    val newHiveContext = hiveContext.newSession()

    assert(newHiveContext.isInstanceOf[SapHiveContext])
    assert(newHiveContext != hiveContext)
  }

  test("NewSession returns a hive context whose catalog is separated to the current one") {
    val newContext = sqlc.newSession()
    val emptyRdd = newContext.createDataFrame(sc.emptyRDD[Row], schema)
    emptyRdd.registerTempTable("foo")

    assert(!sqlc.tableNames().contains("foo"))
    assert(newContext.tableNames().contains("foo"))
  }
} 
Example 42
Source File: DummyRelationUtils.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.{ColumnName, Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.sources.sql.SqlLikeRelation
import org.apache.spark.sql.types.{StructField, StructType}


  case class DummyCatalystSourceRelation(
      schema: StructType,
      isMultiplePartitionExecutionFunc: Option[Seq[CatalystSource] => Boolean] = None,
      supportsLogicalPlanFunc: Option[LogicalPlan => Boolean] = None,
      supportsExpressionFunc: Option[Expression => Boolean] = None,
      logicalPlanToRDDFunc: Option[LogicalPlan => RDD[Row]] = None)
     (@transient implicit val sqlContext: SQLContext)
    extends BaseRelation
    with CatalystSource {

    override def isMultiplePartitionExecution(relations: Seq[CatalystSource]): Boolean =
      isMultiplePartitionExecutionFunc.forall(_.apply(relations))

    override def supportsLogicalPlan(plan: LogicalPlan): Boolean =
      supportsLogicalPlanFunc.forall(_.apply(plan))

    override def supportsExpression(expr: Expression): Boolean =
      supportsExpressionFunc.forall(_.apply(expr))

    override def logicalPlanToRDD(plan: LogicalPlan): RDD[Row] =
      logicalPlanToRDDFunc.getOrElse(
        (plan: LogicalPlan) => new LogicalPlanRDD(plan, sqlContext.sparkContext)).apply(plan)
  }
} 
Example 43
Source File: CatalystSourceAndDatasourceTestSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package com.sap.spark

import org.apache.spark.sql.execution.tablefunctions.TPCHTables
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{GlobalSapSQLContext, Row}
import org.apache.spark.util.DummyRelationUtils._
import org.scalatest.FunSuite
import com.sap.spark.util.TestUtils._


// scalastyle:off magic.number
class CatalystSourceAndDatasourceTestSuite
  extends FunSuite
  with GlobalSapSQLContext {

  test("Select with group by (bug 116823)") {
    val rdd = sc.parallelize(Seq(Row("a", 1L)))
    registerMockCatalystRelation(
      tableName = "foo",
      schema = StructType('a1.double :: 'a2.int :: 'a3.string :: Nil),
      data = rdd)

    val df = sqlc.sql("SELECT a3 AS MYALIAS, COUNT(a1) FROM foo GROUP BY a3")

    assertResult(Array(Row("a", 1)))(df.collect())
  }

  test("Select with group by and having (bug 116824)") {
    val rdd = sc.parallelize(Seq(Row(900L, "a"), Row(101L, "a"), Row(1L, "b")), numSlices = 3)
    val ordersSchema = TPCHTables(sqlc).ordersSchema
    registerMockCatalystRelation("ORDERS", ordersSchema, rdd)

    val df = sqlc.sql(
      """SELECT O_ORDERSTATUS, count(O_ORDERKEY) AS NumberOfOrders
      |FROM ORDERS
      |GROUP BY O_ORDERSTATUS
      |HAVING count(O_ORDERKEY) > 1000""".stripMargin)

    assertResult(Seq(Row("a", 1001L)))(df.collect().toSeq)
  }

  test("Average pushdown"){
    val rdd = sc.parallelize(
      Seq(Row("name1", 20.0, 10L), Row("name2", 10.0, 10L),
          Row("name1", 10.0, 10L), Row("name2", 20.0, 10L)),
      numSlices = 2)
    registerMockCatalystRelation("persons", StructType('name.string :: 'age.int :: Nil), rdd)

    val result =
      sqlContext.sql(s"SELECT name, avg(age) FROM persons GROUP BY name").collect().toSet

    assertResult(Set(Row("name1", 1.5), Row("name2", 1.5)))(result)
  }

  test("Nested query") {
    val rdd = sc.parallelize(Seq(Row(5), Row(5), Row(5), Row(1)), numSlices = 2)
    registerMockCatalystRelation(
      tableName = "fourColumns",
      StructType(('a' to 'e').map(char => Symbol(char.toString).int)),
      data = rdd)

    val result = sqlContext.sql(s"SELECT COUNT(*) FROM (SELECT e,sum(d) " +
      s"FROM fourColumns GROUP BY e) as A").collect().toSet

    assertResult(Set(Row(2)))(result)
  }
} 
Example 44
Source File: TestUtils.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package com.sap.spark.util

import java.util.Locale

import scala.io.Source
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.{Row, SQLContext, SapSQLContext}
import org.apache.spark.sql.hive.SapHiveContext
import org.apache.spark.sql.sources.sql.SqlLikeRelation
import org.apache.spark.sql.sources.{BaseRelation, CatalystSource, Table}
import org.apache.spark.sql.types.StructType
import org.mockito.Matchers._
import org.mockito.Mockito._

import scala.tools.nsc.io.Directory
import scala.util.{Failure, Success}


  def parsePTestFile(fileName: String): List[(String, String, String)] = {
    val filePath = getFileFromClassPath(fileName)
    val fileContents = Source.fromFile(filePath).getLines
      .map(p => p.stripMargin.trim)
      .filter(p => !p.isEmpty && !p.startsWith("//")) // skip empty lines and comments
      .mkString("\n")
    val p = new PTestFileParser

    // strip semicolons from query and parsed
    p(fileContents) match {
      case Success(lines) =>
        lines.map {
          case (query, parsed, expect) =>
            (stripSemicolon(query).trim, stripSemicolon(parsed).trim, expect.trim)
        }
      case Failure(ex) => throw ex
    }
  }

  private def stripSemicolon(sql: String): String =
    if (sql.endsWith(";")) {
      sql.substring(0, sql.length-1)
    } else {
      sql
    }

  def withTempDirectory[A](f: Directory => A): A = {
    val dir = Directory.makeTemp()
    try {
      f(dir)
    } finally {
      dir.deleteIfExists()
    }
  }
} 
Example 45
Source File: AvroTransformer.scala    From streamliner-examples   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark.examples.avro

import com.memsql.spark.etl.api.{UserTransformConfig, Transformer, PhaseConfig}
import com.memsql.spark.etl.utils.PhaseLogger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, DataFrame, Row}
import org.apache.spark.sql.types.StructType

import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.avro.io.DecoderFactory
import org.apache.avro.specific.SpecificDatumReader

// Takes DataFrames of byte arrays, where each row is a serialized Avro record.
// Returns DataFrames of deserialized data, where each field has its own column.
class AvroTransformer extends Transformer {
  var avroSchemaStr: String = null
  var sparkSqlSchema: StructType = null

  def AvroRDDToDataFrame(sqlContext: SQLContext, rdd: RDD[Row]): DataFrame = {

    val rowRDD: RDD[Row] = rdd.mapPartitions({ partition => {
      // Create per-partition copies of non-serializable objects
      val parser: Schema.Parser = new Schema.Parser()
      val avroSchema = parser.parse(avroSchemaStr)
      val reader = new SpecificDatumReader[GenericData.Record](avroSchema)

      partition.map({ rowOfBytes =>
        val bytes = rowOfBytes(0).asInstanceOf[Array[Byte]]
        val decoder = DecoderFactory.get().binaryDecoder(bytes, null)
        val record = reader.read(null, decoder)
        val avroToRow = new AvroToRow()

        avroToRow.getRow(record)
      })
    }})
    sqlContext.createDataFrame(rowRDD, sparkSqlSchema)
  }

  override def initialize(sqlContext: SQLContext, config: PhaseConfig, logger: PhaseLogger): Unit = {
    val userConfig = config.asInstanceOf[UserTransformConfig]

    val avroSchemaJson = userConfig.getConfigJsValue("avroSchema") match {
      case Some(s) => s
      case None => throw new IllegalArgumentException("avroSchema must be set in the config")
    }
    avroSchemaStr = avroSchemaJson.toString

    val parser = new Schema.Parser()
    val avroSchema = parser.parse(avroSchemaJson.toString)
    sparkSqlSchema = AvroToSchema.getSchema(avroSchema)
  }

  override def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = {
    AvroRDDToDataFrame(sqlContext, df.rdd)
  }
} 
Example 46
Source File: ExcelOutputWriterFactory.scala    From spark-hadoopoffice-ds   with Apache License 2.0 5 votes vote down vote up
package org.zuinnote.spark.office.excel

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.execution.datasources.{ OutputWriter, OutputWriterFactory }
import org.apache.spark.sql.types.StructType

import org.zuinnote.hadoop.office.format.mapreduce.ExcelFileOutputFormat
import org.zuinnote.hadoop.office.format.common.HadoopOfficeWriteConfiguration

private[excel] class ExcelOutputWriterFactory(options: Map[String, String]) extends OutputWriterFactory {

  def newInstance(
    path:       String,
    bucketId:   Option[Int],
    dataSchema: StructType,
    context:    TaskAttemptContext): OutputWriter = {
    new ExcelOutputWriter(path, dataSchema, context, options)
  }

  def newInstance(
    path:       String,
    dataSchema: StructType,
    context:    TaskAttemptContext): OutputWriter = {
    new ExcelOutputWriter(path, dataSchema, context, options)
  }

  def getFileExtension(context: TaskAttemptContext): String = {
    val conf = context.getConfiguration();
    val defaultConf = conf.get(HadoopOfficeWriteConfiguration.CONF_MIMETYPE, ExcelFileOutputFormat.DEFAULT_MIMETYPE);
    conf.set(HadoopOfficeWriteConfiguration.CONF_MIMETYPE, defaultConf);
    ExcelFileOutputFormat.getSuffix(conf.get(HadoopOfficeWriteConfiguration.CONF_MIMETYPE))
  }

} 
Example 47
Source File: ExcelRelation.scala    From spark-hadoopoffice-ds   with Apache License 2.0 5 votes vote down vote up
package org.zuinnote.spark.office.excel

import scala.collection.JavaConversions._

import org.apache.spark.sql.sources.{ BaseRelation, TableScan }
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.SQLContext

import org.apache.spark.sql._
import org.apache.spark.rdd.RDD

import org.apache.hadoop.conf._
import org.apache.hadoop.mapreduce._

import org.apache.commons.logging.LogFactory
import org.apache.commons.logging.Log

import org.zuinnote.hadoop.office.format.common.dao._
import org.zuinnote.hadoop.office.format.mapreduce._

import org.zuinnote.spark.office.excel.util.ExcelFile


  override def buildScan: RDD[Row] = {
    // read ExcelRows
    val excelRowsRDD = ExcelFile.load(sqlContext, location, hadoopParams)
    // map to schema
    val schemaFields = schema.fields
    excelRowsRDD.flatMap(excelKeyValueTuple => {
      // map each Excel row's cell data into a Spark SQL Row matching the schema
      val rowArray = new Array[Any](excelKeyValueTuple._2.get.length)
      var i = 0;
      for (x <- excelKeyValueTuple._2.get) { // parse through the SpreadSheetCellDAO
        val spreadSheetCellDAOStructArray = new Array[String](schemaFields.length)
        val currentSpreadSheetCellDAO: Array[SpreadSheetCellDAO] = excelKeyValueTuple._2.get.asInstanceOf[Array[SpreadSheetCellDAO]]
        spreadSheetCellDAOStructArray(0) = currentSpreadSheetCellDAO(i).getFormattedValue
        spreadSheetCellDAOStructArray(1) = currentSpreadSheetCellDAO(i).getComment
        spreadSheetCellDAOStructArray(2) = currentSpreadSheetCellDAO(i).getFormula
        spreadSheetCellDAOStructArray(3) = currentSpreadSheetCellDAO(i).getAddress
        spreadSheetCellDAOStructArray(4) = currentSpreadSheetCellDAO(i).getSheetName
        // add the struct representing one spreadsheet cell to the row
        rowArray(i) = spreadSheetCellDAOStructArray
        i += 1
      }
      Some(Row.fromSeq(rowArray))
    })

  }

} 
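The five string entries filled per cell in buildScan above suggest a per-cell struct. The real ExcelRelation defines its schema elsewhere in spark-hadoopoffice-ds, so the field names below are assumptions used only to illustrate how such a schema could be expressed with StructType and ArrayType:

import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

object ExcelCellSchemaSketch {
  // Field names guessed from the five getters used in buildScan above.
  val cellStruct: StructType = StructType(Seq(
    StructField("formattedValue", StringType),
    StructField("comment", StringType),
    StructField("formula", StringType),
    StructField("address", StringType),
    StructField("sheetName", StringType)))

  // One Excel row then becomes an array of such cell structs.
  val excelRow: ArrayType = ArrayType(cellStruct, containsNull = true)
}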
Example 48
Source File: ActionsHandler.scala    From spark-http-stream   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.http

import java.util.Properties
import scala.collection.mutable.ArrayBuffer
import org.apache.kafka.clients.producer.KafkaProducer
import org.apache.kafka.clients.producer.ProducerRecord
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import java.sql.Timestamp
import org.apache.spark.sql.types.StructType
import java.util.concurrent.atomic.AtomicInteger


	def listActionHandlerEntries(requestBody: Map[String, Any]): ActionHandlerEntries;
	def destroy();
}

trait ActionsHandlerFactory {
	def createInstance(params: Params): ActionsHandler;
}

abstract class AbstractActionsHandler extends ActionsHandler {
	def getRequiredParam(requestBody: Map[String, Any], key: String): Any = {
		val opt = requestBody.get(key);
		if (opt.isEmpty) {
			throw new MissingRequiredRequestParameterException(key);
		}

		opt.get;
	}

	override def destroy() = {
	}
}

class NullActionsHandler extends AbstractActionsHandler {
	override def listActionHandlerEntries(requestBody: Map[String, Any]): ActionHandlerEntries = new ActionHandlerEntries() {
		def apply(action: String) = Map[String, Any]();
		//yes, do nothing
		def isDefinedAt(action: String) = false;
	};
}

// rich row with extra info: batch id, offset within the batch, and a timestamp
case class RowEx(originalRow: Row, batchId: Long, offsetInBatch: Long, timestamp: Timestamp) {
	def withTimestamp(): Row = Row.fromSeq(originalRow.toSeq :+ timestamp);
	def withId(): Row = Row.fromSeq(originalRow.toSeq :+ s"$batchId-$offsetInBatch");
	def extra: (Long, Long, Timestamp) = { (batchId, offsetInBatch, timestamp) };
}

trait SendStreamActionSupport {
	def onReceiveStream(topic: String, rows: Array[RowEx]);
	def getRequiredParam(requestBody: Map[String, Any], key: String): Any;

	val listeners = ArrayBuffer[StreamListener]();

	def addListener(listener: StreamListener): this.type = {
		listeners += listener;
		this;
	}

	protected def notifyListeners(topic: String, data: Array[RowEx]) {
		listeners.foreach { _.onArrive(topic, data); }
	}

	def handleSendStream(requestBody: Map[String, Any]): Map[String, Any] = {
		val topic = getRequiredParam(requestBody, "topic").asInstanceOf[String];
		val batchId = getRequiredParam(requestBody, "batchId").asInstanceOf[Long];
		val rows = getRequiredParam(requestBody, "rows").asInstanceOf[Array[Row]];
		val ts = new Timestamp(System.currentTimeMillis());
		var index = -1;
		val rows2 = rows.map { row ⇒
			index += 1;
			RowEx(Row.fromSeq(row.toSeq), batchId, index, ts)
		}

		onReceiveStream(topic, rows2);
		notifyListeners(topic, rows2);
		Map("rowsCount" -> rows.size);
	}
} 
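RowEx.withTimestamp and RowEx.withId append values to the row but not to any schema. A minimal sketch of the matching schema adjustments, building on the RowEx case class defined above; the column names "_TIMESTAMP_" and "_ID_" are assumptions of this snippet:

import java.sql.Timestamp

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}

object RowExSchemaSketch {
  // Schema counterpart of RowEx.withTimestamp (column name is an assumption).
  def withTimestampColumn(schema: StructType): StructType =
    StructType(schema.fields :+ StructField("_TIMESTAMP_", TimestampType, nullable = false))

  // Schema counterpart of RowEx.withId (column name is an assumption).
  def withIdColumn(schema: StructType): StructType =
    StructType(schema.fields :+ StructField("_ID_", StringType, nullable = false))

  def example(): (Row, Row) = {
    val ex = RowEx(Row("payload"), batchId = 1L, offsetInBatch = 0L,
      timestamp = new Timestamp(System.currentTimeMillis()))
    (ex.withTimestamp(), ex.withId())
  }
}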
Example 49
Source File: utils.scala    From spark-http-stream   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.http

import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.TimestampType
import org.apache.spark.SparkConf
import org.apache.commons.io.IOUtils
import org.apache.spark.serializer.KryoSerializer
import java.io.InputStream
import com.esotericsoftware.kryo.io.Input
import java.io.ByteArrayOutputStream

class WrongArgumentException(name: String, value: Any)
		extends RuntimeException(s"wrong argument: $name=$value") {
}

class MissingRequiredArgumentException(map: Map[String, String], paramName: String)
		extends RuntimeException(s"missing required argument: $paramName, all parameters=$map") {
}

class InvalidSerializerNameException(serializerName: String)
		extends RuntimeException(s"invalid serializer name: $serializerName") {
}

object SchemaUtils {
	def buildSchema(schema: StructType, includesTimestamp: Boolean, timestampColumnName: String = "_TIMESTAMP_"): StructType = {
		if (!includesTimestamp)
			schema;
		else
			StructType(schema.fields.toSeq :+ StructField(timestampColumnName, TimestampType, false));
	}
}

object Params {
	
	def deserialize(bytes: Array[Byte]): Any = {
		val kryo = kryoSerializer.newKryo();
		val input = new Input();
		input.setBuffer(bytes);
		kryo.readClassAndObject(input);
	}
} 
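A quick usage sketch of the SchemaUtils.buildSchema helper defined above (the surrounding package is assumed to be in scope):

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object SchemaUtilsUsageSketch {
  val base = StructType(Seq(
    StructField("name", StringType),
    StructField("age", IntegerType)))

  // Unchanged when no timestamp column is requested.
  val plain = SchemaUtils.buildSchema(base, includesTimestamp = false)

  // Otherwise a non-nullable "_TIMESTAMP_" column is appended at the end.
  val withTimestamp = SchemaUtils.buildSchema(base, includesTimestamp = true)
}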
Example 50
Source File: HttpStreamServerClientTest.scala    From spark-http-stream   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.http.HttpStreamClient
import org.junit.Assert
import org.junit.Test
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.types.FloatType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.ByteType
import org.apache.spark.sql.execution.streaming.http.HttpStreamServer
import org.apache.spark.sql.execution.streaming.http.StreamPrinter
import org.apache.spark.sql.execution.streaming.http.HttpStreamServerSideException


class HttpStreamServerClientTest {
	val ROWS1 = Array(Row("hello1", 1, true, 0.1f, 0.1d, 1L, '1'.toByte),
		Row("hello2", 2, false, 0.2f, 0.2d, 2L, '2'.toByte),
		Row("hello3", 3, true, 0.3f, 0.3d, 3L, '3'.toByte));

	val ROWS2 = Array(Row("hello"),
		Row("world"),
		Row("bye"),
		Row("world"));

	@Test
	def testHttpStreamIO() {
		//starts a http server
		val kryoSerializer = new KryoSerializer(new SparkConf());
		val server = HttpStreamServer.start("/xxxx", 8080);

		val spark = SparkSession.builder.appName("testHttpTextSink").master("local[4]")
			.getOrCreate();
		spark.conf.set("spark.sql.streaming.checkpointLocation", "/tmp/");

		val sqlContext = spark.sqlContext;
		import spark.implicits._
		//add a local message buffer to server, with 2 topics registered
		server.withBuffer()
			.addListener(new StreamPrinter())
			.createTopic[(String, Int, Boolean, Float, Double, Long, Byte)]("topic-1")
			.createTopic[String]("topic-2");

		val client = HttpStreamClient.connect("http://localhost:8080/xxxx");
		//tests schema of topics
		val schema1 = client.fetchSchema("topic-1");
		Assert.assertArrayEquals(Array[Object](StringType, IntegerType, BooleanType, FloatType, DoubleType, LongType, ByteType),
			schema1.fields.map(_.dataType).asInstanceOf[Array[Object]]);

		val schema2 = client.fetchSchema("topic-2");
		Assert.assertArrayEquals(Array[Object](StringType),
			schema2.fields.map(_.dataType).asInstanceOf[Array[Object]]);

		//prepare to consume messages
		val sid1 = client.subscribe("topic-1")._1;
		val sid2 = client.subscribe("topic-2")._1;

		//produces some data
		client.sendRows("topic-1", 1, ROWS1);

		val sid4 = client.subscribe("topic-1")._1;
		val sid5 = client.subscribe("topic-2")._1;

		client.sendRows("topic-2", 1, ROWS2);

		//consumes data
		val fetched = client.fetchStream(sid1).map(_.originalRow);
		Assert.assertArrayEquals(ROWS1.asInstanceOf[Array[Object]], fetched.asInstanceOf[Array[Object]]);
		//it is empty now
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid1).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(ROWS2.asInstanceOf[Array[Object]], client.fetchStream(sid2).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid4).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(ROWS2.asInstanceOf[Array[Object]], client.fetchStream(sid5).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid5).map(_.originalRow).asInstanceOf[Array[Object]]);

		client.unsubscribe(sid4);
		try {
			client.fetchStream(sid4);
			//exception should be thrown, because subscriber id is invalidated
			Assert.assertTrue(false);
		}
		catch {
			case e: Throwable ⇒
				e.printStackTrace();
				Assert.assertEquals(classOf[HttpStreamServerSideException], e.getClass);
		}

		server.stop();
	}
} 
Example 51
Source File: MultinomialLogisticRegressionParitySpec.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.parity.classification

import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.{Matrices, Vectors}
import org.apache.spark.ml.parity.SparkParityBase
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

class MultinomialLogisticRegressionParitySpec extends SparkParityBase {

  val labels = Seq(0.0, 1.0, 2.0, 0.0, 1.0, 2.0)
  val ages = Seq(15, 30, 40, 50, 15, 80)
  val heights = Seq(175, 190, 155, 160, 170, 180)
  val weights = Seq(67, 100, 57, 56, 56, 88)

  val rows = spark.sparkContext.parallelize(Seq.tabulate(6) { i => Row(labels(i), ages(i), heights(i), weights(i)) })
  val schema = new StructType().add("label", DoubleType, nullable = false)
    .add("age", IntegerType, nullable = false)
    .add("height", IntegerType, nullable = false)
    .add("weight", IntegerType, nullable = false)

  override val dataset: DataFrame = spark.sqlContext.createDataFrame(rows, schema)

  override val sparkTransformer: Transformer = new Pipeline().setStages(Array(
    new VectorAssembler().
      setInputCols(Array("age", "height", "weight")).
      setOutputCol("features"),
    new LogisticRegressionModel(uid = "logr", 
      coefficientMatrix = Matrices.dense(3, 3, Array(-1.3920551604166562, -0.13119545493644366, 1.5232506153530998, 0.3129112131192873, -0.21959056436528473, -0.09332064875400257, -0.24696506013528507, 0.6122879917796569, -0.36532293164437174)),
      interceptVector = Vectors.dense(0.4965574044951358, -2.1486146169780063, 1.6520572124828703),
      numClasses = 3, isMultinomial = true))).fit(dataset)
} 
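The builder-style schema above can equivalently be written with an explicit field list; a short sketch using only standard Spark types:

import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

object ParitySchemaSketch {
  val builderStyle = new StructType()
    .add("label", DoubleType, nullable = false)
    .add("age", IntegerType, nullable = false)

  val explicitStyle = StructType(Seq(
    StructField("label", DoubleType, nullable = false),
    StructField("age", IntegerType, nullable = false)))

  // Both forms describe the same fields.
  require(builderStyle.fields.toSeq == explicitStyle.fields.toSeq)
}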
Example 52
Source File: SparkSupport.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package ml.combust.mleap.spark

import java.net.URI

import ml.combust.bundle.dsl.Bundle
import ml.combust.bundle.{BundleFile, BundleWriter}
import ml.combust.mleap.core.types
import ml.combust.mleap.runtime.frame
import ml.combust.mleap.runtime.frame.Row
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.bundle.SparkBundleContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.mleap.TypeConverters
import org.apache.spark.sql.types.StructType
import resource._

import scala.util.Try


trait SparkSupport {
  implicit class SparkTransformerOps(transformer: Transformer) {
    def writeBundle: BundleWriter[SparkBundleContext, Transformer] = BundleWriter(transformer)
  }

  implicit class SparkBundleFileOps(file: BundleFile) {
    def loadSparkBundle()
                       (implicit context: SparkBundleContext): Try[Bundle[Transformer]] = file.load()
  }

  implicit class URIBundleFileOps(uri: URI) {
    def loadMleapBundle()
                       (implicit context: SparkBundleContext): Try[Bundle[Transformer]] = {
      (for (bf <- managed(BundleFile.load(uri))) yield {
        bf.load[SparkBundleContext, Transformer]().get
      }).tried
    }
  }

  implicit class MleapSparkTransformerOps[T <: frame.Transformer](transformer: T) {
    def sparkTransform(dataset: DataFrame): DataFrame = {
      transformer.transform(dataset.toSparkLeapFrame).get.toSpark
    }
  }

  implicit class SparkDataFrameOps(dataset: DataFrame) {
    def toSparkLeapFrame: SparkLeapFrame = {
      val spec = dataset.schema.fields.
        map(f => TypeConverters.sparkToMleapConverter(dataset, f))
      val schema = types.StructType(spec.map(_._1)).get
      val converters = spec.map(_._2)
      val data = dataset.rdd.map(r => {
        val values = r.toSeq.zip(converters).map {
          case (v, c) => c(v)
        }
        Row(values: _*)
      })

      SparkLeapFrame(schema, data, dataset.sqlContext)
    }

    def mleapSchema: types.StructType = TypeConverters.sparkSchemaToMleapSchema(dataset)
  }

  implicit class MleapSchemaOps(schema: types.StructType) {
    def toSpark: StructType = TypeConverters.mleapSchemaToSparkSchema(schema)
  }
}

object SparkSupport extends SparkSupport 
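A small sketch of using the implicit ops defined above to move a schema between the Spark and MLeap representations; `df` stands for any DataFrame already at hand (an assumption of this snippet):

import ml.combust.mleap.core.types
import ml.combust.mleap.spark.SparkSupport._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

object SchemaConversionSketch {
  // Convert a DataFrame's schema to the MLeap representation and back again,
  // using the implicit classes defined in SparkSupport above.
  def convertBothWays(df: DataFrame): (types.StructType, StructType) = {
    val mleapSchema: types.StructType = df.mleapSchema
    (mleapSchema, mleapSchema.toSpark)
  }
}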
Example 53
Source File: SparkTransformBuilderSpec.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package ml.combust.mleap.spark

import java.util.UUID

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StructType}
import SparkSupport._
import ml.combust.mleap.core.{Model, types}
import ml.combust.mleap.core.types.{NodeShape, ScalarType, StructField}
import ml.combust.mleap.runtime.frame.{FrameBuilder, Transformer}
import org.scalatest.FunSpec

import scala.collection.JavaConverters._
import scala.util.Try


case class MyTransformer() extends Transformer {
  override val uid: String = UUID.randomUUID().toString

  override def transform[TB <: FrameBuilder[TB]](builder: TB): Try[TB] = {
    builder.withColumns(Seq("output1", "output2"), "input") {
      (input: Double) => (input + 23, input.toString)
    }
  }

  override val shape: NodeShape = NodeShape().withStandardInput("input").
    withOutput("output1", "output1").withOutput("output2", "output2")

  override val model: Model = new Model {
    override def inputSchema: types.StructType = types.StructType("input" -> ScalarType.Double).get

    override def outputSchema: types.StructType = types.StructType("output1" -> ScalarType.Double,
      "output2" -> ScalarType.String).get
  }
}

class SparkTransformBuilderSpec extends FunSpec {
  describe("transformer with multiple outputs") {
    it("works with Spark as well") {
      val spark = SparkSession.builder().
        appName("Spark/MLeap Parity Tests").
        master("local[2]").
        getOrCreate()
      val schema = new StructType().
        add("input", DoubleType)
      val data = Seq(Row(45.7d)).asJava
      val dataset = spark.createDataFrame(data, schema)
      val transformer = MyTransformer()
      val outputDataset = transformer.sparkTransform(dataset).collect()

      assert(outputDataset.head.getDouble(1) == 68.7)
      assert(outputDataset.head.getString(2) == "45.7")
    }
  }

  describe("input/output schema") {
    it("has the correct inputs and outputs") {
      val transformer = MyTransformer()
      assert(transformer.schema.fields ==
        Seq(StructField("input", types.ScalarType.Double),
          StructField("output1", types.ScalarType.Double),
          StructField("output2", types.ScalarType.String)))
    }
  }
} 
Example 54
Source File: MathUnary.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.mleap.feature

import ml.combust.mleap.core.feature.{MathUnaryModel, UnaryOperation}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
import org.apache.spark.sql.functions.udf


    private val className = classOf[MathUnary].getName

    override def load(path: String): MathUnary = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val dataPath = new Path(path, "data").toString

      val data = sparkSession.read.parquet(dataPath).select("operation").head()
      val operation = data.getAs[String](0)

      val model = MathUnaryModel(UnaryOperation.forName(operation))
      val transformer = new MathUnary(metadata.uid, model)

      metadata.getAndSetParams(transformer)
      transformer
    }
  }

} 
Example 55
Source File: ImputerParitySpec.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.mleap.parity.feature

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.mleap.feature.Imputer
import org.apache.spark.ml.parity.SparkParityBase
import org.apache.spark.sql._
import org.apache.spark.sql.types.{DoubleType, StructType}

import scala.util.Random


class ImputerParitySpec extends SparkParityBase {
  def randomRow(): Row = {
    if(Random.nextBoolean()) {
      if(Random.nextBoolean()) {
        Row(23.4)
      } else { Row(Random.nextDouble()) }
    } else {
      Row(33.2)
    }
  }
  val rows = spark.sparkContext.parallelize(Seq.tabulate(100) { i => randomRow() })
  val schema = new StructType().add("mv", DoubleType, nullable = true)

  override val dataset: DataFrame = spark.sqlContext.createDataFrame(rows, schema)
  override val sparkTransformer: Transformer = new Imputer(uid = "imputer").
    setInputCol("mv").
    setOutputCol("mv_imputed").
    setMissingValue(23.4).
    setStrategy("mean").fit(dataset)
} 
Example 56
Source File: StringMapParitySpec.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.mleap.parity.feature

import ml.combust.mleap.core.feature.{HandleInvalid, StringMapModel}
import org.apache.spark.ml.mleap.feature.StringMap
import org.apache.spark.ml.parity.SparkParityBase
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.sql._
import org.apache.spark.sql.types.{StringType, StructType}

class StringMapParitySpec extends SparkParityBase {
  val names = Seq("alice", "andy", "kevin")
  val rows = spark.sparkContext.parallelize(Seq.tabulate(3) { i => Row(names(i)) })
  val schema = new StructType().add("name", StringType, nullable = false)

  override val dataset: DataFrame = spark.sqlContext.createDataFrame(rows, schema)

  override val sparkTransformer: Transformer = new Pipeline().setStages(Array(
    new StringMap(uid = "string_map", model = new StringMapModel(
      Map("alice" -> 0, "andy" -> 1, "kevin" -> 2)
    )).setInputCol("name").setOutputCol("index"),
    new StringMap(uid = "string_map2", model = new StringMapModel(
      // This map is missing the label "kevin". Exception is thrown unless HandleInvalid.Keep is set.
      Map("alice" -> 0, "andy" -> 1),
      handleInvalid = HandleInvalid.Keep, defaultValue = 1.0
    )).setInputCol("name").setOutputCol("index2")

  )).fit(dataset)
} 
Example 57
Source File: WrappersSpec.scala    From sparksql-scalapb   with Apache License 2.0 5 votes vote down vote up
package scalapb.spark

import com.example.protos.wrappers._
import org.apache.spark.sql.SparkSession
import org.apache.hadoop.io.ArrayPrimitiveWritable
import scalapb.GeneratedMessageCompanion
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.Row

import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

class WrappersSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll {
  val spark: SparkSession = SparkSession
    .builder()
    .appName("ScalaPB Demo")
    .master("local[2]")
    .getOrCreate()

  import spark.implicits.StringToColumn

  val data = Seq(
    PrimitiveWrappers(
      intValue = Option(45),
      stringValue = Option("boo"),
      ints = Seq(17, 19, 25),
      strings = Seq("foo", "bar")
    ),
    PrimitiveWrappers(
      intValue = None,
      stringValue = None,
      ints = Seq(17, 19, 25),
      strings = Seq("foo", "bar")
    )
  )

  "converting df with primitive wrappers" should "work with primitive implicits" in {
    import ProtoSQL.withPrimitiveWrappers.implicits._
    val df = ProtoSQL.withPrimitiveWrappers.createDataFrame(spark, data)
    df.schema.fields.map(_.dataType).toSeq must be(
      Seq(
        IntegerType,
        StringType,
        ArrayType(IntegerType, false),
        ArrayType(StringType, false)
      )
    )
    df.collect must contain theSameElementsAs (
      Seq(
        Row(45, "boo", Seq(17, 19, 25), Seq("foo", "bar")),
        Row(null, null, Seq(17, 19, 25), Seq("foo", "bar"))
      )
    )
  }

  "converting df with primitive wrappers" should "work with default implicits" in {
    import ProtoSQL.implicits._
    val df = ProtoSQL.createDataFrame(spark, data)
    df.schema.fields.map(_.dataType).toSeq must be(
      Seq(
        StructType(Seq(StructField("value", IntegerType, true))),
        StructType(Seq(StructField("value", StringType, true))),
        ArrayType(
          StructType(Seq(StructField("value", IntegerType, true))),
          false
        ),
        ArrayType(
          StructType(Seq(StructField("value", StringType, true))),
          false
        )
      )
    )
    df.collect must contain theSameElementsAs (
      Seq(
        Row(
          Row(45),
          Row("boo"),
          Seq(Row(17), Row(19), Row(25)),
          Seq(Row("foo"), Row("bar"))
        ),
        Row(
          null,
          null,
          Seq(Row(17), Row(19), Row(25)),
          Seq(Row("foo"), Row("bar"))
        )
      )
    )
  }
} 
Example 58
Source File: L8-4DataFrameCreationSchema.scala    From prosparkstreaming   with Apache License 2.0 5 votes vote down vote up
package org.apress.prospark

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.desc
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.StructType
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext

object DataframeCreationApp2 {

  def main(args: Array[String]) {
    if (args.length != 5) {
      System.err.println(
        "Usage: CdrDataframeApp2 <appname> <batchInterval> <hostname> <port> <schemaPath>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port, schemaFile) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)

    val schemaJson = scala.io.Source.fromFile(schemaFile).mkString
    val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType]

    val cdrStream = ssc.socketTextStream(hostname, port.toInt)
      .map(_.split("\\t", -1))
      .foreachRDD(rdd => {
        val cdrs = sqlC.createDataFrame(rdd.map(c => Row(c: _*)), schema)
        
        cdrs.groupBy("countryCode").count().orderBy(desc("count")).show(5)
      })

    ssc.start()
    ssc.awaitTermination()

  }
} 
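The schema file read above can be produced from a StructType itself: prettyJson serializes it and DataType.fromJson parses it back. A sketch with a reduced two-column schema (an assumption; the real CDR schema has more columns):

import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}

object SchemaJsonSketch {
  // A reduced stand-in for the CDR schema read from <schemaPath> above.
  val schema = StructType(Seq(
    StructField("squareId", IntegerType),
    StructField("countryCode", IntegerType)))

  // Write this string to the schema file...
  val asJson: String = schema.prettyJson

  // ...and parse it back exactly as the streaming app does.
  val parsedBack: StructType = DataType.fromJson(asJson).asInstanceOf[StructType]
  require(parsedBack.fields.toSeq == schema.fields.toSeq)
}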
Example 59
Source File: L8-35DataFrameExamplesRDD.scala    From prosparkstreaming   with Apache License 2.0 5 votes vote down vote up
package org.apress.prospark

import scala.reflect.runtime.universe

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.StructType
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext
import org.json4s.DefaultFormats

object CdrDataframeExamplesRDDApp {

  case class Cdr(squareId: Int, timeInterval: Long, countryCode: Int,
    smsInActivity: Float, smsOutActivity: Float, callInActivity: Float,
    callOutActivity: Float, internetTrafficActivity: Float)

  def main(args: Array[String]) {
    if (args.length != 5) {
      System.err.println(
        "Usage: CdrDataframeExamplesRDDApp <appname> <batchInterval> <hostname> <schemaPath>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port, schemaFile) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)
    import sqlC.implicits._
    implicit val formats = DefaultFormats

    val schemaJson = scala.io.Source.fromFile(schemaFile).mkString
    val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType]

    val cdrStream = ssc.socketTextStream(hostname, port.toInt)
      .map(_.split("\\t", -1))
      .foreachRDD(rdd => {
        val cdrs = seqToCdr(rdd).toDF()
        val highInternet = sqlC.createDataFrame(cdrs.rdd.filter(r => r.getFloat(3) + r.getFloat(4) >= r.getFloat(5) + r.getFloat(6)), schema)
        val highOther = cdrs.except(highInternet)
        val highInternetGrid = highInternet.select("squareId", "countryCode").dropDuplicates()
        val highOtherGrid = highOther.select("squareId", "countryCode").dropDuplicates()
        highOtherGrid.except(highInternetGrid).show()
        highInternetGrid.except(highOtherGrid).show()
      })

    ssc.start()
    ssc.awaitTermination()
  }

  def seqToCdr(rdd: RDD[Array[String]]): RDD[Cdr] = {
    rdd.map(c => c.map(f => f match {
      case x if x.isEmpty() => "0"
      case x => x
    })).map(c => Cdr(c(0).toInt, c(1).toLong, c(2).toInt, c(3).toFloat,
      c(4).toFloat, c(5).toFloat, c(6).toFloat, c(7).toFloat))
  }
} 
Example 60
Source File: StreamStaticJoiner.scala    From structured-streaming-application   with Apache License 2.0 5 votes vote down vote up
package knolx.spark

import knolx.Config._
import knolx.KnolXLogger
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.functions.{col, from_json}
import org.apache.spark.sql.types.StructType


object StreamStaticJoiner extends App with KnolXLogger {
  info("Creating Spark Session")
  val spark = SparkSession.builder().master(sparkMaster).appName(sparkAppName).getOrCreate()
  spark.sparkContext.setLogLevel("WARN")

  info("Static Dataframe")
  val companiesDF = spark.read.option("header", "true").csv("src/main/resources/companies.csv")
  companiesDF.show(false)

  info("Original Streaming Dataframe")
  val schema = ScalaReflection.schemaFor[Stock].dataType.asInstanceOf[StructType]
  val stockStreamDF =
    spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServer)
      .option("subscribe", topic)
      .load()
      .select(from_json(col("value").cast("string"), schema).as("value"))
      .select("value.*")

  stockStreamDF.printSchema()
  stockStreamDF.writeStream.format("console").start()

  info("Filtered Streaming Dataframe")
  val filteredStockStreamDF = stockStreamDF.join(companiesDF, "companyName")
  val filteredStockStreamingQuery = filteredStockStreamDF.writeStream.format("console").start()

  info("Waiting for the query to terminate...")
  filteredStockStreamingQuery.awaitTermination()
  filteredStockStreamingQuery.stop()
} 
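The Stock case class passed to ScalaReflection.schemaFor is defined elsewhere in this project, so the field list below is only a guess to keep the sketch self-contained; it also shows the Encoders.product alternative for deriving the same kind of StructType from a case class:

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType

// Hypothetical stand-in for the project's Stock case class.
case class StockSketch(companyName: String, stockName: String, price: Double)

object CaseClassSchemaSketch {
  // Derivation via Catalyst reflection, as in the streaming apps above.
  val viaReflection: StructType =
    ScalaReflection.schemaFor[StockSketch].dataType.asInstanceOf[StructType]

  // Equivalent derivation via an Encoder.
  val viaEncoder: StructType = Encoders.product[StockSketch].schema
}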
Example 61
Source File: StreamStreamOuterJoiner.scala    From structured-streaming-application   with Apache License 2.0 5 votes vote down vote up
package knolx.spark

import knolx.Config._
import knolx.KnolXLogger
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.functions.{col, expr, from_json}
import org.apache.spark.sql.types.StructType


object StreamStreamOuterJoiner extends App with KnolXLogger {
  info("Creating Spark Session")
  val spark = SparkSession.builder().master(sparkMaster).appName(sparkAppName).getOrCreate()
  spark.sparkContext.setLogLevel("WARN")

  info("Streaming companies Dataframe")
  val companiesDF =
    spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServer)
      .option("subscribe", companiesTopic)
      .load()
      .select(col("value").cast("string").as("companyName"),
        col("timestamp").as("companyTradingTime"))
      .withWatermark("companyTradingTime", "10 seconds")

  companiesDF.writeStream.format("console").option("truncate", false).start()

  info("Original Streaming Dataframe")
  val schema = ScalaReflection.schemaFor[Stock].dataType.asInstanceOf[StructType]
  val stockStreamDF =
    spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServer)
      .option("subscribe", stocksTopic)
      .load()
      .select(from_json(col("value").cast("string"), schema).as("value"),
        col("timestamp").as("stockInputTime"))
      .select("value.*", "stockInputTime")
      .withWatermark("stockInputTime", "10 seconds")

  info("Filtered Streaming Dataframe")
  val filteredStockStreamDF = stockStreamDF.join(companiesDF,
    expr("companyName = stockName AND stockInputTime >= companyTradingTime AND stockInputTime <= companyTradingTime + interval 20 seconds"),
    joinType = "leftOuter")
  val filteredStockStreamingQuery = filteredStockStreamDF.writeStream.format("console").option("truncate", false).start()

  info("Waiting for the query to terminate...")
  filteredStockStreamingQuery.awaitTermination()
  filteredStockStreamingQuery.stop()
} 
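The interval join condition above is written as a single SQL string passed to expr. A minimal sketch of the same condition built from typed column expressions instead, assuming the stockStreamDF and companiesDF frames defined above:

import org.apache.spark.sql.functions.{col, expr}

// equivalent join condition assembled from column expressions rather than one SQL string
val joinCondition =
  col("companyName") === col("stockName") &&
    col("stockInputTime") >= col("companyTradingTime") &&
    col("stockInputTime") <= col("companyTradingTime") + expr("interval 20 seconds")

val outerJoinedDF = stockStreamDF.join(companiesDF, joinCondition, "leftOuter")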
Example 62
Source File: StreamStreamJoiner.scala    From structured-streaming-application   with Apache License 2.0 5 votes vote down vote up
package knolx.spark

import knolx.Config._
import knolx.KnolXLogger
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.functions.{col, expr, from_json}
import org.apache.spark.sql.types.StructType


object StreamStreamJoiner extends App with KnolXLogger {
  info("Creating Spark Session")
  val spark = SparkSession.builder().master(sparkMaster).appName(sparkAppName).getOrCreate()
  spark.sparkContext.setLogLevel("WARN")

  info("Streaming companies Dataframe")
  val companiesDF =
    spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServer)
      .option("subscribe", companiesTopic)
      .load()
      .select(col("value").cast("string").as("companyName"),
        col("timestamp").as("companyTradingTime"))

  companiesDF.writeStream.format("console").option("truncate", false).start()

  info("Original Streaming Dataframe")
  val schema = ScalaReflection.schemaFor[Stock].dataType.asInstanceOf[StructType]
  val stockStreamDF =
    spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", bootstrapServer)
      .option("subscribe", stocksTopic)
      .load()
      .select(from_json(col("value").cast("string"), schema).as("value"),
        col("timestamp").as("stockInputTime"))
      .select("value.*", "stockInputTime")

  info("Filtered Streaming Dataframe")
  val filteredStockStreamDF = stockStreamDF.join(companiesDF,
    expr("companyName = stockName AND stockInputTime >= companyTradingTime AND stockInputTime <= companyTradingTime + interval 20 seconds"))
  val filteredStockStreamingQuery = filteredStockStreamDF.writeStream.format("console").option("truncate", false).start()

  info("Waiting for the query to terminate...")
  filteredStockStreamingQuery.awaitTermination()
  filteredStockStreamingQuery.stop()
} 
Example 63
Source File: EventHubsRelation.scala    From azure-event-hubs-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.eventhubs

import org.apache.spark.eventhubs.rdd.{ EventHubsRDD, OffsetRange }
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{ Row, SQLContext }
import org.apache.spark.sql.sources.{ BaseRelation, TableScan }
import org.apache.spark.sql.types.StructType

import scala.language.postfixOps


private[eventhubs] class EventHubsRelation(override val sqlContext: SQLContext,
                                           parameters: Map[String, String])
    extends BaseRelation
    with TableScan
    with Logging {

  import org.apache.spark.eventhubs._

  private val ehConf = EventHubsConf.toConf(parameters)
  private val eventHubClient = EventHubsSourceProvider.clientFactory(parameters)(ehConf)

  override def schema: StructType = EventHubsSourceProvider.eventHubsSchema

  override def buildScan(): RDD[Row] = {
    val partitionCount: Int = eventHubClient.partitionCount

    val fromSeqNos = eventHubClient.translate(ehConf, partitionCount)
    val untilSeqNos = eventHubClient.translate(ehConf, partitionCount, useStart = false)

    require(fromSeqNos.forall(f => f._2 >= 0L),
            "Currently only sequence numbers can be passed in your starting positions.")
    require(untilSeqNos.forall(u => u._2 >= 0L),
            "Currently only sequence numbers can be passed in your ending positions.")

    val offsetRanges = untilSeqNos.keySet.map { p =>
      val fromSeqNo = fromSeqNos
        .getOrElse(p, throw new IllegalStateException(s"$p doesn't have a fromSeqNo"))
      val untilSeqNo = untilSeqNos(p)
      OffsetRange(ehConf.name, p, fromSeqNo, untilSeqNo, None)
    }.toArray
    eventHubClient.close()

    logInfo(
      "GetBatch generating RDD of with offsetRanges: " +
        offsetRanges.sortBy(_.nameAndPartition.toString).mkString(", "))

    val rdd = EventHubsSourceProvider.toInternalRow(
      new EventHubsRDD(sqlContext.sparkContext, ehConf.trimmed, offsetRanges))
    sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = false).rdd
  }
} 
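A TableScan relation like this is turned into a DataFrame by the SQLContext; because the class is private[eventhubs], user code normally goes through the data source format instead. A minimal sketch of both paths, assuming a populated parameters map and that the provider's short name is "eventhubs":

// inside the org.apache.spark.sql.eventhubs package the relation can be built directly
val relation = new EventHubsRelation(sqlContext, parameters)
val df = sqlContext.baseRelationToDataFrame(relation) // buildScan() runs when df is evaluated

// from user code, the provider is addressed by its format name (assumed here to be "eventhubs")
val batchDF = sqlContext.sparkSession.read
  .format("eventhubs")
  .options(parameters)
  .load()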
Example 64
Source File: DefaultSource.scala    From magellan   with Apache License 2.0 5 votes vote down vote up
package magellan

import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, SchemaRelationProvider, RelationProvider}


class DefaultSource extends RelationProvider
  with SchemaRelationProvider {

  override def createRelation(sqlContext: SQLContext,
    parameters: Map[String, String]): BaseRelation = createRelation(sqlContext, parameters, null)

  override def createRelation(sqlContext: SQLContext,
    parameters: Map[String, String], schema: StructType): BaseRelation = {
    val path = parameters.getOrElse("path", sys.error("'path' must be specified for Shapefiles."))
    val t = parameters.getOrElse("type", "shapefile")
    t match {
      case "shapefile" => new ShapeFileRelation(path, parameters)(sqlContext)
      case "geojson" => new GeoJSONRelation(path, parameters)(sqlContext)
      case "osm" => new OsmFileRelation(path, parameters)(sqlContext)
      case _ => ???
    }
  }

} 
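A RelationProvider like this is resolved through the DataFrameReader, with load(path) supplying the required "path" parameter. A minimal usage sketch, assuming a SparkSession named spark, the "magellan" format name, and hypothetical file paths:

// read a GeoJSON file through the provider above; the path is hypothetical
val geoDF = spark.read
  .format("magellan")
  .option("type", "geojson")
  .load("/data/neighborhoods.geojson")

// omitting "type" falls back to the default "shapefile" branch
val shapeDF = spark.read.format("magellan").load("/data/shapes/")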
Example 65
Source File: DefaultSource.scala    From spark-excel   with Apache License 2.0 5 votes vote down vote up
package com.crealytics.spark.excel

import com.crealytics.spark.excel.Utils._
import org.apache.hadoop.fs.Path
import org.apache.poi.ss.util.{CellRangeAddress, CellReference}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

import scala.util.Try

class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {

  
  override def createRelation(
    sqlContext: SQLContext,
    parameters: Map[String, String],
    schema: StructType
  ): ExcelRelation = {
    val wbReader = WorkbookReader(parameters, sqlContext.sparkContext.hadoopConfiguration)
    val dataLocator = DataLocator(parameters)
    ExcelRelation(
      header = checkParameter(parameters, "header").toBoolean,
      treatEmptyValuesAsNulls = parameters.get("treatEmptyValuesAsNulls").fold(false)(_.toBoolean),
      userSchema = Option(schema),
      inferSheetSchema = parameters.get("inferSchema").fold(false)(_.toBoolean),
      addColorColumns = parameters.get("addColorColumns").fold(false)(_.toBoolean),
      timestampFormat = parameters.get("timestampFormat"),
      excerptSize = parameters.get("excerptSize").fold(10)(_.toInt),
      dataLocator = dataLocator,
      workbookReader = wbReader
    )(sqlContext)
  }

  override def createRelation(
    sqlContext: SQLContext,
    mode: SaveMode,
    parameters: Map[String, String],
    data: DataFrame
  ): BaseRelation = {
    val path = checkParameter(parameters, "path")
    val header = checkParameter(parameters, "header").toBoolean
    val filesystemPath = new Path(path)
    val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
    new ExcelFileSaver(
      fs,
      filesystemPath,
      data,
      saveMode = mode,
      header = header,
      dataLocator = DataLocator(parameters)
    ).save()

    createRelation(sqlContext, parameters, data.schema)
  }

  // Forces a Parameter to exist, otherwise an exception is thrown.
  private def checkParameter(map: Map[String, String], param: String): String = {
    if (!map.contains(param)) {
      throw new IllegalArgumentException(s"Parameter ${'"'}$param${'"'} is missing in options.")
    } else {
      map.apply(param)
    }
  }
} 
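From the reader side, these parameters arrive as DataFrameReader options. A minimal sketch of loading a workbook through this provider, assuming a SparkSession named spark and a hypothetical file path; "header" is mandatory (it goes through checkParameter above), the other options fall back to the defaults shown in createRelation:

val excelDF = spark.read
  .format("com.crealytics.spark.excel")
  .option("header", "true")
  .option("inferSchema", "true")
  .option("treatEmptyValuesAsNulls", "true")
  .load("/data/report.xlsx")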
Example 66
Source File: KustoResponseDeserializer.scala    From azure-kusto-spark   with Apache License 2.0 5 votes vote down vote up
package com.microsoft.kusto.spark.datasource

import java.sql.Timestamp
import java.util

import com.microsoft.azure.kusto.data.{KustoResultColumn, KustoResultSetTable, Results}
import com.microsoft.kusto.spark.utils.DataTypeMapping
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{StructType, _}
import org.joda.time.DateTime

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object KustoResponseDeserializer {
  def apply(kustoResult: KustoResultSetTable): KustoResponseDeserializer = new KustoResponseDeserializer(kustoResult)
}

// Timespan columns are cast to strings on the Kusto side. A simple test that translated them to a Duration string in
// timespan format showed worse performance. One approach used a new expression extending UnaryExpression and another
// used a UDF; both were less performant than the plain string cast.
case class KustoSchema(sparkSchema: StructType, toStringCastedColumns: Set[String])

class KustoResponseDeserializer(val kustoResult: KustoResultSetTable) {
  val schema: KustoSchema = getSchemaFromKustoResult

  private def getValueTransformer(valueType: String): Any => Any = {

    valueType.toLowerCase() match {
      case "string" => value: Any => value
      case "int64" => value: Any => value
      case "datetime" => value: Any => new Timestamp(new DateTime(value).getMillis)
      case "timespan" => value: Any => value
      case "sbyte" => value: Any => value
      case "long" => value: Any => value match {
        case i: Int => i.toLong
        case _ => value.asInstanceOf[Long]
      }
      case "double" => value: Any => value
      case "decimal" => value: Any => BigDecimal(value.asInstanceOf[String])
      case "int" => value: Any => value
      case "int32" => value: Any => value
      case "bool" => value: Any => value
      case "real" => value: Any => value
      case _ => value: Any => value.toString
      }
  }

   private def getSchemaFromKustoResult: KustoSchema = {
    if (kustoResult.getColumns.isEmpty) {
      KustoSchema(StructType(List()), Set())
    } else {
      val columns = kustoResult.getColumns

      KustoSchema(StructType(columns.map(col => StructField(col.getColumnName,
            DataTypeMapping.kustoTypeToSparkTypeMap.getOrElse(col.getColumnType.toLowerCase, StringType)))),
        columns.filter(c => c.getColumnType.equalsIgnoreCase("TimeSpan")).map(c => c.getColumnName).toSet)
    }
  }

  def getSchema: KustoSchema = { schema }

  def toRows: java.util.List[Row] = {
    val columnInOrder = kustoResult.getColumns
    val value: util.ArrayList[Row] = new util.ArrayList[Row](kustoResult.count())

    // Compute the transformer function for each column, in column order, for reuse when building the rows below
    val valueTransformers: mutable.Seq[Any => Any] = columnInOrder.map(col => getValueTransformer(col.getColumnType))
    kustoResult.getData.asScala.foreach(row => {
      val genericRow = row.toArray().zipWithIndex.map(
        column => {
          if (column._1 == null) null else valueTransformers(column._2)(column._1)
        })
      value.add(new GenericRowWithSchema(genericRow, schema.sparkSchema))
    })

    value
  }

//  private def getOrderedColumnName = {
//    val columnInOrder = ArrayBuffer.fill(kustoResult.getColumnNameToIndex.size()){ "" }
//    kustoResult.getColumns.foreach((columnIndexPair: KustoResultColumn) => columnInOrder(columnIndexPair.) = columnIndexPair._1)
//    columnInOrder
//  }
} 
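toRows wraps each transformed value array in a GenericRowWithSchema so every row carries the Kusto-derived Spark schema. A minimal standalone sketch of that wrapping, assuming a hypothetical two-column schema in place of schema.sparkSchema:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

// hypothetical two-column schema standing in for schema.sparkSchema above
val sparkSchema = StructType(Seq(
  StructField("name", StringType),
  StructField("count", LongType)))

// values that have already passed through their per-column transformers
val row: Row = new GenericRowWithSchema(Array[Any]("spark", 42L), sparkSchema)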
Example 67
Source File: DataTypeMapping.scala    From azure-kusto-spark   with Apache License 2.0 5 votes vote down vote up
package com.microsoft.kusto.spark.utils

import org.apache.spark.sql.types.DataTypes._
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, MapType, StructType}

object DataTypeMapping {

  val kustoTypeToSparkTypeMap: Map[String, DataType] = Map(
    "string" -> StringType,
    "long" -> LongType,
    "datetime" -> TimestampType,// Kusto datetime is equivalent to TimestampType
    "timespan" -> StringType,
    "bool" -> BooleanType,
    "real" -> DoubleType,
    // Precision and scale can be partitioned differently, but their total must be 34 to match .NET SqlDecimal
    "decimal" -> DataTypes.createDecimalType(20,14),
    "guid" -> StringType,
    "int" -> IntegerType,
    "dynamic" -> StringType
  )

  val kustoJavaTypeToSparkTypeMap: Map[String, DataType] = Map(
    "string" -> StringType,
    "int64" -> LongType,
    "datetime" -> TimestampType,
    "timespan" -> StringType,
    "sbyte" -> BooleanType,
    "double" -> DoubleType,
    "sqldecimal" -> DataTypes.createDecimalType(20,14),
    "guid" -> StringType,
    "int32" -> IntegerType,
    "object" -> StringType
  )

  val sparkTypeToKustoTypeMap: Map[DataType, String] = Map(
    StringType ->  "string",
    BooleanType -> "bool",
    DateType -> "datetime",
    TimestampType -> "datetime",
    DataTypes.createDecimalType() -> "decimal",
    DoubleType -> "real",
    FloatType -> "real",
    ByteType -> "int",
    IntegerType -> "int",
    LongType ->  "long",
    ShortType ->  "int"
  )

  def getSparkTypeToKustoTypeMap(fieldType: DataType): String ={
      if(fieldType.isInstanceOf[DecimalType]) "decimal"
      else if (fieldType.isInstanceOf[ArrayType] || fieldType.isInstanceOf[StructType] || fieldType.isInstanceOf[MapType]) "dynamic"
      else DataTypeMapping.sparkTypeToKustoTypeMap.getOrElse(fieldType, "string")
  }
} 
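A typical use of the Spark-to-Kusto direction is rendering a DataFrame schema as Kusto column declarations. A minimal sketch, assuming a DataFrame named df is in scope:

// map every Spark field to its Kusto type name using the helper above
val kustoColumns = df.schema.fields
  .map(field => s"${field.name}:${DataTypeMapping.getSparkTypeToKustoTypeMap(field.dataType)}")
  .mkString(", ")
// e.g. "name:string, amount:decimal, createdAt:datetime"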
Example 68
Source File: KustoSourceTests.scala    From azure-kusto-spark   with Apache License 2.0 5 votes vote down vote up
package com.microsoft.kusto.spark

import com.microsoft.kusto.spark.datasource.KustoSourceOptions
import com.microsoft.kusto.spark.utils.{KustoDataSourceUtils => KDSU}
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.junit.runner.RunWith
import org.scalamock.scalatest.MockFactory
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class KustoSourceTests extends FlatSpec with MockFactory with Matchers with BeforeAndAfterAll {
  private val loggingLevel: Option[String] = Option(System.getProperty("logLevel"))
  if (loggingLevel.isDefined) KDSU.setLoggingLevel(loggingLevel.get)

  private val nofExecutors = 4
  private val spark: SparkSession = SparkSession.builder()
    .appName("KustoSource")
    .master(f"local[$nofExecutors]")
    .getOrCreate()

  private var sc: SparkContext = _
  private var sqlContext: SQLContext = _
  private val cluster: String = "KustoCluster"
  private val database: String = "KustoDatabase"
  private val query: String = "KustoTable"
  private val appId: String = "KustoSinkTestApplication"
  private val appKey: String = "KustoSinkTestKey"
  private val appAuthorityId: String = "KustoSinkAuthorityId"

  override def beforeAll(): Unit = {
    super.beforeAll()

    sc = spark.sparkContext
    sqlContext = spark.sqlContext
  }

  override def afterAll(): Unit = {
    super.afterAll()

    sc.stop()
  }

  "KustoDataSource" should "recognize Kusto and get the correct schema" in {
    val spark: SparkSession = SparkSession.builder()
      .appName("KustoSource")
      .master(f"local[$nofExecutors]")
      .getOrCreate()

    val customSchema = "colA STRING, colB INT"

    val df = spark.sqlContext
      .read
      .format("com.microsoft.kusto.spark.datasource")
      .option(KustoSourceOptions.KUSTO_CLUSTER, cluster)
      .option(KustoSourceOptions.KUSTO_DATABASE, database)
      .option(KustoSourceOptions.KUSTO_QUERY, query)
      .option(KustoSourceOptions.KUSTO_AAD_APP_ID, appId)
      .option(KustoSourceOptions.KUSTO_AAD_APP_SECRET, appKey)
      .option(KustoSourceOptions.KUSTO_AAD_AUTHORITY_ID, appAuthorityId)
      .option(KustoSourceOptions.KUSTO_CUSTOM_DATAFRAME_COLUMN_TYPES, customSchema)
      .load("src/test/resources/")

    val expected = StructType(Array(StructField("colA", StringType, nullable = true),StructField("colB", IntegerType, nullable = true)))
    assert(df.schema.equals(expected))
  }
} 
Example 69
Source File: CustomDateParser.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.json

import com.stratio.sparta.sdk.pipeline.transformation.Parser
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import java.io.{Serializable => JSerializable}

import org.joda.time.DateTime

import scala.util.Try

class CustomDateParser(order: Integer,
                       inputField: Option[String],
                       outputFields: Seq[String],
                       schema: StructType,
                       properties: Map[String, JSerializable])
  extends Parser(order, inputField, outputFields, schema, properties) {

  val dateField = propertiesWithCustom.getString("dateField", "date")
  val hourField = propertiesWithCustom.getString("hourField", "hourRounded")
  val dayField = propertiesWithCustom.getString("dayField", "dayRounded")
  val weekField = propertiesWithCustom.getString("weekField", "week")
  val hourDateField = propertiesWithCustom.getString("hourDateField", "hourDate")
  val yearPrefix = propertiesWithCustom.getString("yearPrefix", "20")

  //scalastyle:off
  override def parse(row: Row): Seq[Row] = {
    val inputValue = Try(row.get(schema.fieldIndex(dateField))).toOption
    val newData = Try {
      inputValue match {
        case Some(value) =>
          val valueStr = {
            value match {
              case valueCast: Array[Byte] => new Predef.String(valueCast)
              case valueCast: String => valueCast
              case _ => value.toString
            }
          }

          val valuesParsed = Map(
            hourField -> getDateWithBeginYear(valueStr).concat(valueStr.substring(4, valueStr.length)),
            hourDateField -> getHourDate(valueStr),
            dayField -> getDateWithBeginYear(valueStr).concat(valueStr.substring(4, valueStr.length - 2)),
            weekField -> getWeek(valueStr)
          )

          outputFields.map { outputField =>
            val outputSchemaValid = outputFieldsSchema.find(field => field.name == outputField)
            outputSchemaValid match {
              case Some(outSchema) =>
                valuesParsed.get(outSchema.name) match {
                  case Some(valueParsed) =>
                    parseToOutputType(outSchema, valueParsed)
                  case None =>
                    returnWhenError(new IllegalStateException(
                      s"The values parsed don't contain the schema field: ${outSchema.name}"))
                }
              case None =>
                returnWhenError(new IllegalStateException(
                  s"Impossible to parse outputField: $outputField in the schema"))
            }
          }
        case None =>
          returnWhenError(new IllegalStateException(s"The input value is null or empty"))
      }
    }

    returnData(newData, removeInputField(row))
  }

  def getDateWithBeginYear(inputDate: String): String =
    inputDate.substring(0, inputDate.length - 4).concat(yearPrefix)

  def getHourDate(inputDate: String): Long = {
    val day = inputDate.substring(0, 2).toInt
    val month = inputDate.substring(2, 4).toInt
    val year = yearPrefix.concat(inputDate.substring(4, 6)).toInt
    val hour = inputDate.substring(6, inputDate.length).toInt
    val date = new DateTime(year, month, day, hour, 0)

    date.getMillis
  }

  def getWeek(inputDate: String): Int = {
    val day = inputDate.substring(0, 2).toInt
    val month = inputDate.substring(2, 4).toInt
    val year = yearPrefix.concat(inputDate.substring(4, 6)).toInt
    val date = new DateTime(year, month, day, 0, 0)

    date.getWeekOfWeekyear
  }

  //scalastyle:on
} 
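The helpers above assume a compact ddMMyyHH-style layout: two-digit day, month and year, followed by the hour, with yearPrefix supplying the century. A minimal standalone sketch of the same decomposition for the hypothetical input "01071912":

import org.joda.time.DateTime

// hypothetical input in the ddMMyyHH layout expected by getHourDate and getWeek above
val input = "01071912"
val day = input.substring(0, 2).toInt           // 1
val month = input.substring(2, 4).toInt         // 7
val year = ("20" + input.substring(4, 6)).toInt // 2019 (yearPrefix defaults to "20")
val hour = input.substring(6).toInt             // 12
val hourDateMillis = new DateTime(year, month, day, hour, 0).getMillis
val week = new DateTime(year, month, day, 0, 0).getWeekOfWeekyear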
Example 70
Source File: DataFrameModifierHelper.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.writer

import com.stratio.sparta.sdk.pipeline.autoCalculations.AutoCalculatedField
import com.stratio.sparta.sdk.pipeline.output.Output
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame}

object DataFrameModifierHelper {

  def applyAutoCalculateFields(dataFrame: DataFrame,
                               autoCalculateFields: Seq[AutoCalculatedField],
                               auxSchema: StructType): DataFrame =
    autoCalculateFields.headOption match {
      case Some(firstAutoCalculate) =>
        applyAutoCalculateFields(
          addColumnToDataFrame(dataFrame, firstAutoCalculate, auxSchema), autoCalculateFields.drop(1), auxSchema)
      case None =>
        dataFrame
    }

  private[driver] def addColumnToDataFrame(dataFrame: DataFrame,
                                   autoCalculateField: AutoCalculatedField,
                                   auxSchema: StructType): DataFrame = {
    (autoCalculateField.fromNotNullFields,
      autoCalculateField.fromPkFields,
      autoCalculateField.fromFields,
      autoCalculateField.fromFixedValue) match {
      case (Some(fromNotNullFields), _, _, _) =>
        val fields = fieldsWithAuxMetadata(dataFrame.schema.fields, auxSchema.fields).flatMap(field =>
          if (!field.nullable) Some(col(field.name)) else None).toSeq
        addField(fromNotNullFields.field.name, fromNotNullFields.field.outputType, dataFrame, fields)
      case (None, Some(fromPkFields), _, _) =>
        val fields = fieldsWithAuxMetadata(dataFrame.schema.fields, auxSchema.fields).flatMap(field =>
          if (field.metadata.contains(Output.PrimaryKeyMetadataKey)) Some(col(field.name)) else None).toSeq
        addField(fromPkFields.field.name, fromPkFields.field.outputType, dataFrame, fields)
      case (None, None, Some(fromFields), _) =>
        val fields = autoCalculateField.fromFields.get.fromFields.map(field => col(field))
        addField(fromFields.field.name, fromFields.field.outputType, dataFrame, fields)
      case (None, None, None, Some(fromFixedValue)) =>
        addLiteral(fromFixedValue.field.name, fromFixedValue.field.outputType, dataFrame, fromFixedValue.value)
      case _ => dataFrame
    }
  }

  private[driver] def addField(name: String, outputType: String, dataFrame: DataFrame, fields: Seq[Column]): DataFrame =
    outputType match {
      case "string" => dataFrame.withColumn(name, concat_ws(Output.Separator, fields: _*))
      case "array" => dataFrame.withColumn(name, array(fields: _*))
      case "map" => dataFrame.withColumn(name, struct(fields: _*))
      case _ => dataFrame
    }

  private[driver] def addLiteral(name: String, outputType: String, dataFrame: DataFrame, literal: String): DataFrame =
    outputType match {
      case "string" => dataFrame.withColumn(name, lit(literal))
      case "array" => dataFrame.withColumn(name, array(lit(literal)))
      case "map" => dataFrame.withColumn(name, struct(lit(literal)))
      case _ => dataFrame
    }

  private[driver] def fieldsWithAuxMetadata(dataFrameFields: Array[StructField], auxFields: Array[StructField]) =
    dataFrameFields.map(field => {
      auxFields.find(auxField => auxField.name == field.name) match {
        case Some(auxFounded) => field.copy(metadata = auxFounded.metadata)
        case None => field
      }
    })
} 
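The "string" branches above boil down to withColumn calls over existing fields or literals. A minimal standalone sketch of both transformations, assuming a DataFrame named df, hypothetical column names, and "_" standing in for Output.Separator:

import org.apache.spark.sql.functions.{col, concat_ws, lit}

// derived column built from existing fields, mirroring the addField "string" branch
val withKey = df.withColumn("generatedKey", concat_ws("_", col("product"), col("timestamp")))

// fixed-value column, mirroring the addLiteral "string" branch
val withSource = withKey.withColumn("source", lit("streaming"))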
Example 71
Source File: TransformationsWriterHelper.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.writer

import com.stratio.sparta.driver.factory.SparkContextFactory
import com.stratio.sparta.sdk.pipeline.output.Output
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.streaming.dstream.DStream

object TransformationsWriterHelper {

  def writeTransformations(input: DStream[Row],
                           inputSchema: StructType,
                           outputs: Seq[Output],
                           writerOptions: WriterOptions): Unit = {
    input.foreachRDD(rdd =>
      if (!rdd.isEmpty()) {
        val transformationsDataFrame = SparkContextFactory.sparkSessionInstance.createDataFrame(rdd, inputSchema)

        WriterHelper.write(transformationsDataFrame, writerOptions, Map.empty[String, String], outputs)
      }
    )
  }
} 
Example 72
Source File: RawDataWriterHelper.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.writer

import com.stratio.sparta.driver.factory.SparkContextFactory
import com.stratio.sparta.driver.step.RawData
import com.stratio.sparta.sdk.pipeline.output.Output
import com.stratio.sparta.sdk.utils.AggregationTime
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
import org.apache.spark.streaming.dstream.DStream


object RawDataWriterHelper {

  def writeRawData(rawData: RawData, outputs: Seq[Output], input: DStream[Row]): Unit = {
    val RawSchema = StructType(Seq(
      StructField(rawData.timeField, TimestampType, nullable = false),
      StructField(rawData.dataField, StringType, nullable = true)))
    val eventTime = AggregationTime.millisToTimeStamp(System.currentTimeMillis())

    input.map(row => Row.merge(Row(eventTime), row))
      .foreachRDD(rdd => {
        if (!rdd.isEmpty()) {
          val rawDataFrame = SparkContextFactory.sparkSessionInstance.createDataFrame(rdd, RawSchema)

          WriterHelper.write(rawDataFrame, rawData.writerOptions, Map.empty[String, String], outputs)
        }
      })
  }
} 
Example 73
Source File: WriterHelper.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.writer

import akka.event.slf4j.SLF4JLogging
import com.stratio.sparta.driver.schema.SchemaHelper
import com.stratio.sparta.sdk.pipeline.output.Output
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

import scala.util.{Failure, Success, Try}

object WriterHelper extends SLF4JLogging {

  def write(dataFrame: DataFrame,
            writerOptions: WriterOptions,
            extraSaveOptions: Map[String, String],
            outputs: Seq[Output]): DataFrame = {
    val saveOptions = extraSaveOptions ++
      writerOptions.tableName.fold(Map.empty[String, String]) { outputTableName =>
        Map(Output.TableNameKey -> outputTableName)
      } ++
      writerOptions.partitionBy.fold(Map.empty[String, String]) { partition =>
        Map(Output.PartitionByKey -> partition)
      } ++
      writerOptions.primaryKey.fold(Map.empty[String, String]) { key =>
        Map(Output.PrimaryKey -> key)
      }
    val outputTableName = saveOptions.getOrElse(Output.TableNameKey, "undefined")
    val autoCalculatedFieldsDf = DataFrameModifierHelper.applyAutoCalculateFields(dataFrame,
        writerOptions.autoCalculateFields,
        StructType(dataFrame.schema.fields ++ SchemaHelper.getStreamWriterPkFieldsMetadata(writerOptions.primaryKey)))

    writerOptions.outputs.foreach(outputName =>
      outputs.find(output => output.name == outputName) match {
        case Some(outputWriter) => Try {
          outputWriter.save(autoCalculatedFieldsDf, writerOptions.saveMode, saveOptions)
        } match {
          case Success(_) =>
            log.debug(s"Data stored in $outputTableName")
          case Failure(e) =>
            log.error(s"Something went wrong writing to table: $outputTableName")
            log.error(s"Schema: ${autoCalculatedFieldsDf.schema}")
            log.error(s"Head element: ${autoCalculatedFieldsDf.head}")
            log.error(s"Error message: ${e.getMessage}")
        }
        case None => log.error(s"The output named $outputName does not match any of the configured outputs")
      })
    autoCalculatedFieldsDf
  }

} 
Example 74
Source File: TriggerWriterHelper.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.writer

import akka.event.slf4j.SLF4JLogging
import org.apache.spark.sql.{DataFrame, Row}
import com.stratio.sparta.driver.exception.DriverException
import com.stratio.sparta.driver.factory.SparkContextFactory
import com.stratio.sparta.driver.schema.SchemaHelper
import com.stratio.sparta.driver.step.Trigger
import com.stratio.sparta.sdk.pipeline.output.Output
import org.apache.spark.sql.types.StructType
import org.apache.spark.streaming.dstream.DStream

import scala.util.{Failure, Success, Try}

object TriggerWriterHelper extends SLF4JLogging {

  def writeStream(triggers: Seq[Trigger],
                  inputTableName: String,
                  outputs: Seq[Output],
                  streamData: DStream[Row],
                  schema: StructType): Unit = {
    streamData.foreachRDD(rdd => {
      val parsedDataFrame = SparkContextFactory.sparkSessionInstance.createDataFrame(rdd, schema)

      writeTriggers(parsedDataFrame, triggers, inputTableName, outputs)
    })
  }

  //scalastyle:off
  def writeTriggers(dataFrame: DataFrame,
                    triggers: Seq[Trigger],
                    inputTableName: String,
                    outputs: Seq[Output]): Unit = {
    val sparkSession = dataFrame.sparkSession
    if (triggers.nonEmpty && isCorrectTableName(inputTableName)) {
      if (!sparkSession.catalog.tableExists(inputTableName)) {
        dataFrame.createOrReplaceTempView(inputTableName)
        log.debug(s"Registering temporal table in Spark with name: $inputTableName")
      }
      val tempTables = triggers.flatMap(trigger => {
        log.debug(s"Executing query in Spark: ${trigger.sql}")
        val queryDf = Try(sparkSession.sql(trigger.sql)) match {
          case Success(sqlResult) => sqlResult
          case Failure(exception: org.apache.spark.sql.AnalysisException) =>
            log.warn(s"Warning running analysis in Catalyst for the query ${trigger.sql} in trigger ${trigger.name}",
              exception.message)
            throw DriverException(exception.getMessage, exception)
          case Failure(exception) =>
            log.warn(s"Warning running SQL for the query ${trigger.sql} in trigger ${trigger.name}", exception.getMessage)
            throw DriverException(exception.getMessage, exception)
        }
        val extraOptions = Map(Output.TableNameKey -> trigger.name)

        if (!queryDf.rdd.isEmpty()) {
          val autoCalculatedFieldsDf = WriterHelper.write(queryDf, trigger.writerOptions, extraOptions, outputs)
          if (isCorrectTableName(trigger.name) && !sparkSession.catalog.tableExists(trigger.name)) {
            autoCalculatedFieldsDf.createOrReplaceTempView(trigger.name)
            log.debug(s"Registering temporal table in Spark with name: ${trigger.name}")
          }
          else log.warn(s"The trigger ${trigger.name} has an invalid name, so it cannot be registered as a temporal table")

          Option(trigger.name)
        } else None
      })
      tempTables.foreach(tableName =>
        if (isCorrectTableName(tableName) && sparkSession.catalog.tableExists(tableName)) {
          sparkSession.catalog.dropTempView(tableName)
          log.debug(s"Dropping temporal table in Spark with name: $tableName")
        } else log.debug(s"Impossible to drop table in Spark with name: $tableName"))

      if (isCorrectTableName(inputTableName) && sparkSession.catalog.tableExists(inputTableName)) {
        sparkSession.catalog.dropTempView(inputTableName)
        log.debug(s"Dropping temporal table in Spark with name: $inputTableName")
      } else log.debug(s"Impossible to drop table in Spark: $inputTableName")
    } else {
      if (triggers.nonEmpty && !isCorrectTableName(inputTableName))
        log.warn(s"Incorrect table name $inputTableName; the triggers may contain errors and may not have been executed")
    }
  }

  //scalastyle:on

  private[driver] def isCorrectTableName(tableName: String): Boolean =
    tableName.nonEmpty &&
      tableName.toLowerCase != "select" &&
      tableName.toLowerCase != "project" &&
      !tableName.contains("-") && !tableName.contains("*") && !tableName.contains("/")
} 
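At its core the trigger flow registers the incoming batch as a temporary view and runs each trigger's SQL against it before dropping the view. A minimal sketch of that cycle outside the helper, assuming a batchDF DataFrame, a SparkSession named spark, the default "stream" table name, and a hypothetical query text:

// register the current micro-batch under the temporal table name used by the triggers
batchDF.createOrReplaceTempView("stream")

// a trigger is essentially a SQL statement over that view
val triggerResult = spark.sql("SELECT companyName, count(*) AS events FROM stream GROUP BY companyName")

// drop the view once all triggers for the batch have been written
spark.catalog.dropTempView("stream")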
Example 75
Source File: TriggerStage.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.stage

import com.stratio.sparta.driver.step.Trigger
import com.stratio.sparta.driver.writer.{TriggerWriterHelper, WriterOptions}
import com.stratio.sparta.sdk.pipeline.output.Output
import com.stratio.sparta.sdk.utils.AggregationTime
import com.stratio.sparta.serving.core.models.policy.PhaseEnum
import com.stratio.sparta.serving.core.models.policy.trigger.TriggerModel
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.streaming.Milliseconds
import org.apache.spark.streaming.dstream.DStream

trait TriggerStage extends BaseStage {
  this: ErrorPersistor =>

  def triggersStreamStage(initSchema: StructType,
                          inputData: DStream[Row],
                          outputs: Seq[Output],
                          window: Long): Unit = {
    val triggersStage = triggerStage(policy.streamTriggers)
    val errorMessage = s"Something went wrong executing the stream triggers for: ${policy.input.get.name}."
    val okMessage = s"Triggers Stream executed correctly."
    generalTransformation(PhaseEnum.TriggerStream, okMessage, errorMessage) {
      triggersStage
        .groupBy(trigger => (trigger.overLast, trigger.computeEvery))
        .foreach { case ((overLast, computeEvery), triggers) =>
          val groupedData = (overLast, computeEvery) match {
            case (None, None) => inputData
            case (Some(overL), Some(computeE))
              if (AggregationTime.parseValueToMilliSeconds(overL) == window) &&
                (AggregationTime.parseValueToMilliSeconds(computeE) == window) => inputData
            case _ => inputData.window(
              Milliseconds(
                overLast.fold(window) { over => AggregationTime.parseValueToMilliSeconds(over) }),
              Milliseconds(
                computeEvery.fold(window) { computeEvery => AggregationTime.parseValueToMilliSeconds(computeEvery) }))
          }
          TriggerWriterHelper.writeStream(triggers, streamTemporalTable(policy.streamTemporalTable), outputs,
            groupedData, initSchema)
        }
    }
  }

  def triggerStage(triggers: Seq[TriggerModel]): Seq[Trigger] =
    triggers.map(trigger => createTrigger(trigger))

  private[driver] def createTrigger(trigger: TriggerModel): Trigger = {
    val okMessage = s"Trigger: ${trigger.name} created correctly."
    val errorMessage = s"Something went wrong creating the trigger: ${trigger.name}. Please re-check the policy."
    generalTransformation(PhaseEnum.Trigger, okMessage, errorMessage) {
      Trigger(
        trigger.name,
        trigger.sql,
        trigger.overLast,
        trigger.computeEvery,
        WriterOptions(
          trigger.writer.outputs,
          trigger.writer.saveMode,
          trigger.writer.tableName,
          getAutoCalculatedFields(trigger.writer.autoCalculatedFields),
          trigger.writer.primaryKey,
          trigger.writer.partitionBy
        ),
        trigger.configuration)
    }
  }

  private[driver] def streamTemporalTable(policyTableName: Option[String]): String =
    policyTableName.flatMap(tableName => if (tableName.nonEmpty) Some(tableName) else None)
      .getOrElse("stream")

} 
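The overLast/computeEvery pair above maps directly onto DStream.window(windowDuration, slideDuration). A minimal sketch, assuming an inputData DStream[Row] and hypothetical durations:

import org.apache.spark.streaming.Milliseconds

// aggregate over the last 30 seconds, recomputing every 10 seconds
val windowedData = inputData.window(Milliseconds(30000L), Milliseconds(10000L))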
Example 76
Source File: ParserStage.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.stage

import java.io.Serializable

import akka.event.slf4j.SLF4JLogging
import com.stratio.sparta.driver.writer.{TransformationsWriterHelper, WriterOptions}
import com.stratio.sparta.sdk.pipeline.output.Output
import com.stratio.sparta.sdk.pipeline.transformation.Parser
import com.stratio.sparta.serving.core.constants.AppConstant
import com.stratio.sparta.serving.core.models.policy.{PhaseEnum, TransformationModel}
import com.stratio.sparta.serving.core.utils.ReflectionUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.streaming.dstream.DStream

import scala.util.{Failure, Success, Try}

trait ParserStage extends BaseStage {
  this: ErrorPersistor =>

  def parserStage(refUtils: ReflectionUtils,
                  schemas: Map[String, StructType]): (Seq[Parser], Option[WriterOptions]) =
    (policy.transformations.get.transformationsPipe.map(parser => createParser(parser, refUtils, schemas)),
      policy.transformations.get.writer.map(writer => WriterOptions(
        writer.outputs,
        writer.saveMode,
        writer.tableName,
        getAutoCalculatedFields(writer.autoCalculatedFields),
        writer.partitionBy,
        writer.primaryKey
      )))

  private[driver] def createParser(model: TransformationModel,
                           refUtils: ReflectionUtils,
                           schemas: Map[String, StructType]): Parser = {
    val classType = model.configuration.getOrElse(AppConstant.CustomTypeKey, model.`type`).toString
    val errorMessage = s"Something went wrong creating the parser: $classType. Please re-check the policy."
    val okMessage = s"Parser: $classType created correctly."
    generalTransformation(PhaseEnum.Parser, okMessage, errorMessage) {
      val outputFieldsNames = model.outputFieldsTransformed.map(_.name)
      val schema = schemas.getOrElse(model.order.toString, throw new Exception("Cannot find the transformation schema"))
      refUtils.tryToInstantiate[Parser](classType + Parser.ClassSuffix, (c) =>
        c.getDeclaredConstructor(
          classOf[Integer],
          classOf[Option[String]],
          classOf[Seq[String]],
          classOf[StructType],
          classOf[Map[String, Serializable]])
          .newInstance(model.order, model.inputField, outputFieldsNames, schema, model.configuration)
          .asInstanceOf[Parser])
    }
  }
}

object ParserStage extends SLF4JLogging {

  def executeParsers(row: Row, parsers: Seq[Parser]): Seq[Row] =
    if (parsers.size == 1) parseEvent(row, parsers.head)
    else parseEvent(row, parsers.head).flatMap(eventParsed => executeParsers(eventParsed, parsers.drop(1)))

  def parseEvent(row: Row, parser: Parser): Seq[Row] =
    Try {
      parser.parse(row)
    } match {
      case Success(eventParsed) =>
        eventParsed
      case Failure(exception) =>
        val error = s"Failure[Parser]: ${row.mkString(",")} | Message: ${exception.getLocalizedMessage}" +
          s" | Parser: ${parser.getClass.getSimpleName}"
        log.error(error, exception)
        Seq.empty[Row]
    }

  def applyParsers(input: DStream[Row],
                   parsers: Seq[Parser],
                   schema: StructType,
                   outputs: Seq[Output],
                   writerOptions: Option[WriterOptions]): DStream[Row] = {
    val transformedData = if (parsers.isEmpty) input
    else input.flatMap(row => executeParsers(row, parsers))

    writerOptions.foreach(options =>
      TransformationsWriterHelper.writeTransformations(transformedData, schema, outputs, options))
    transformedData
  }
} 
Example 77
Source File: CubeMakerTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.driver.test.cube

import java.sql.Timestamp

import com.github.nscala_time.time.Imports._
import com.stratio.sparta.driver.step.{Cube, CubeOperations, Trigger}
import com.stratio.sparta.driver.writer.WriterOptions
import com.stratio.sparta.plugin.default.DefaultField
import com.stratio.sparta.plugin.cube.field.datetime.DateTimeField
import com.stratio.sparta.plugin.cube.operator.count.CountOperator
import com.stratio.sparta.sdk.pipeline.aggregation.cube.{Dimension, DimensionValue, DimensionValuesTime, InputFields}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.utils.AggregationTime
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.streaming.TestSuiteBase
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CubeMakerTest extends TestSuiteBase {

  val PreserverOrder = false

  
  def getEventOutput(timestamp: Timestamp, millis: Long):
  Seq[Seq[(DimensionValuesTime, InputFields)]] = {
    val dimensionString = Dimension("dim1", "eventKey", "identity", new DefaultField)
    val dimensionTime = Dimension("minute", "minute", "minute", new DateTimeField)
    val dimensionValueString1 = DimensionValue(dimensionString, "value1")
    val dimensionValueString2 = dimensionValueString1.copy(value = "value2")
    val dimensionValueString3 = dimensionValueString1.copy(value = "value3")
    val dimensionValueTs = DimensionValue(dimensionTime, timestamp)
    val tsMap = Row(timestamp)
    val valuesMap1 = InputFields(Row("value1", timestamp), 1)
    val valuesMap2 = InputFields(Row("value2", timestamp), 1)
    val valuesMap3 = InputFields(Row("value3", timestamp), 1)

    Seq(Seq(
      (DimensionValuesTime("cubeName", Seq(dimensionValueString1, dimensionValueTs)), valuesMap1),
      (DimensionValuesTime("cubeName", Seq(dimensionValueString2, dimensionValueTs)), valuesMap2),
      (DimensionValuesTime("cubeName", Seq(dimensionValueString3, dimensionValueTs)), valuesMap3)
    ))
  }
} 
Example 78
Source File: Parser.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.sdk.pipeline.transformation

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.properties.{CustomProperties, Parameterizable}
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructField, StructType}

import scala.util.{Failure, Success, Try}

abstract class Parser(order: Integer,
                      inputField: Option[String],
                      outputFields: Seq[String],
                      schema: StructType,
                      properties: Map[String, JSerializable])
  extends Parameterizable(properties) with Ordered[Parser] with CustomProperties {

  val customKey = "transformationOptions"
  val customPropertyKey = "transformationOptionsKey"
  val customPropertyValue = "transformationOptionsValue"
  val propertiesWithCustom = properties ++ getCustomProperties

  val outputFieldsSchema = schema.fields.filter(field => outputFields.contains(field.name))

  val inputFieldRemoved = Try(propertiesWithCustom.getBoolean("removeInputField")).getOrElse(false)

  val inputFieldIndex = inputField match {
    case Some(field) => Try(schema.fieldIndex(field)).getOrElse(0)
    case None => 0
  }

  val whenErrorDo = Try(WhenError.withName(propertiesWithCustom.getString("whenError")))
    .getOrElse(WhenError.Error)

  def parse(data: Row): Seq[Row]

  def getOrder: Integer = order

  def checkFields(keyMap: Map[String, JSerializable]): Map[String, JSerializable] =
    keyMap.flatMap(key => if (outputFields.contains(key._1)) Some(key) else None)

  def compare(that: Parser): Int = this.getOrder.compareTo(that.getOrder)

  //scalastyle:off
  def returnWhenError(exception: Exception): Null =
  whenErrorDo match {
    case WhenError.Null => null
    case _ => throw exception
  }

  //scalastyle:on

  def parseToOutputType(outSchema: StructField, inputValue: Any): Any =
    Try(TypeOp.transformValueByTypeOp(outSchema.dataType, inputValue.asInstanceOf[Any]))
      .getOrElse(returnWhenError(new IllegalStateException(
        s"Error parsing to output type the value: ${inputValue.toString}")))

  def returnData(newData: Try[Seq[_]], prevData: Seq[_]): Seq[Row] =
    newData match {
      case Success(data) => Seq(Row.fromSeq(prevData ++ data))
      case Failure(e) => whenErrorDo match {
        case WhenError.Discard => Seq.empty[Row]
        case _ => throw e
      }
    }

  def returnData(newData: Try[Row], prevData: Row): Seq[Row] =
    newData match {
      case Success(data) => Seq(Row.merge(prevData, data))
      case Failure(e) => whenErrorDo match {
        case WhenError.Discard => Seq.empty[Row]
        case _ => throw e
      }
    }

  def removeIndex(row: Seq[_], inputFieldIndex: Int): Seq[_] = if (row.size < inputFieldIndex) row
    else row.take(inputFieldIndex) ++ row.drop(inputFieldIndex + 1)

  def removeInputField(row: Row): Seq[_] = {
    if (inputFieldRemoved && inputField.isDefined)
      removeIndex(row.toSeq, inputFieldIndex)
    else
      row.toSeq
  }


}

object Parser {

  final val ClassSuffix = "Parser"
  final val DefaultOutputType = "string"
  final val TypesFromParserClass = Map("datetime" -> "timestamp")
} 
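A concrete transformation only has to implement parse and can lean on the error handling provided above. A minimal sketch of a hypothetical UpperCaseParser that upper-cases the configured input field into its single declared output field:

import java.io.{Serializable => JSerializable}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

import scala.util.Try

// hypothetical parser: upper-cases the input field value and appends it to the remaining row fields
class UpperCaseParser(order: Integer,
                      inputField: Option[String],
                      outputFields: Seq[String],
                      schema: StructType,
                      properties: Map[String, JSerializable])
  extends Parser(order, inputField, outputFields, schema, properties) {

  override def parse(row: Row): Seq[Row] = {
    val newData = Try(Seq(row.get(inputFieldIndex).toString.toUpperCase))
    returnData(newData, removeInputField(row))
  }
}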
Example 79
Source File: ParserTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.sdk.pipeline.transformation

import java.io.{Serializable => JSerializable}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ParserTest extends WordSpec with Matchers {

  "Parser" should {

    val parserTest = new ParserMock(
      1,
      Some("input"),
      Seq("output"),
      StructType(Seq(StructField("some", StringType))),
      Map()
    )

    "Order must be " in {
      val expected = 1
      val result = parserTest.getOrder
      result should be(expected)
    }

    "Parse must be " in {
      val event = Row("value")
      val expected = Seq(event)
      val result = parserTest.parse(event)
      result should be(expected)
    }

    "checked fields not be contained in outputs must be " in {
      val keyMap = Map("field" -> "value")
      val expected = Map()
      val result = parserTest.checkFields(keyMap)
      result should be(expected)
    }

    "checked fields are contained in outputs must be " in {
      val keyMap = Map("output" -> "value")
      val expected = keyMap
      val result = parserTest.checkFields(keyMap)
      result should be(expected)
    }

    "classSuffix must be " in {
      val expected = "Parser"
      val result = Parser.ClassSuffix
      result should be(expected)
    }
  }
} 
Example 80
Source File: RedisOutput.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.output.redis

import java.io.Serializable

import com.stratio.sparta.plugin.output.redis.dao.AbstractRedisDAO
import com.stratio.sparta.sdk.pipeline.output.Output._
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row}


class RedisOutput(name: String, properties: Map[String, Serializable])
  extends Output(name, properties) with AbstractRedisDAO with Serializable {

  override val hostname = properties.getString("hostname", DefaultRedisHostname)
  override val port = properties.getString("port", DefaultRedisPort).toInt

  override def save(dataFrame: DataFrame, saveMode: SaveModeEnum.Value, options: Map[String, String]): Unit = {
    val tableName = getTableNameFromOptions(options)
    val schema = dataFrame.schema

    validateSaveMode(saveMode)

    dataFrame.foreachPartition{ rowList =>
      rowList.foreach{ row =>
        val valuesList = getValuesList(row,schema.fieldNames)
        val hashKey = getHashKeyFromRow(valuesList, schema)
        getMeasuresFromRow(valuesList, schema).foreach { case (measure, value) =>
          hset(hashKey, measure.name, value)
        }
      }
    }
  }

  def getHashKeyFromRow(valuesList: Seq[(String, String)], schema: StructType): String =
    valuesList.flatMap{ case (key, value) =>
      val fieldSearch = schema.fields.find(structField =>
        structField.metadata.contains(Output.PrimaryKeyMetadataKey) && structField.name == key)

      fieldSearch.map(structField => s"${structField.name}$IdSeparator$value")
    }.mkString(IdSeparator)

  def getMeasuresFromRow(valuesList: Seq[(String, String)], schema: StructType): Seq[(StructField, String)] =
    valuesList.flatMap{ case (key, value) =>
      val fieldSearch = schema.fields.find(structField =>
          structField.metadata.contains(Output.MeasureMetadataKey) &&
          structField.name == key)
      fieldSearch.map(field => (field, value))
    }

  def getValuesList(row: Row, fieldNames: Array[String]): Seq[(String, String)] =
    fieldNames.zip(row.toSeq).map{ case (key, value) => (key, value.toString)}.toSeq
} 
Example 81
Source File: MongoDbOutput.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.output.mongodb

import java.io.{Serializable => JSerializable}

import com.stratio.datasource.mongodb.config.MongodbConfig
import com.stratio.sparta.sdk.pipeline.output.Output._
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

class MongoDbOutput(name: String, properties: Map[String, JSerializable]) extends Output(name, properties) {

  val DefaultHost = "localhost"
  val DefaultPort = "27017"
  val MongoDbSparkDatasource = "com.stratio.datasource.mongodb"
  val hosts = getConnectionConfs("hosts", "host", "port")
  val dbName = properties.getString("dbName", "sparta")

  override def save(dataFrame: DataFrame, saveMode: SaveModeEnum.Value, options: Map[String, String]): Unit = {
    val tableName = getTableNameFromOptions(options)
    val primaryKeyOption = getPrimaryKeyOptions(options)
    val dataFrameOptions = getDataFrameOptions(tableName, dataFrame.schema, saveMode, primaryKeyOption)

    validateSaveMode(saveMode)

    dataFrame.write
      .format(MongoDbSparkDatasource)
      .mode(getSparkSaveMode(saveMode))
      .options(dataFrameOptions ++ getCustomProperties)
      .save()
  }

  private def getDataFrameOptions(tableName: String,
                                  schema: StructType,
                                  saveMode: SaveModeEnum.Value,
                                  primaryKey: Option[String]
                                 ): Map[String, String] =
    Map(
      MongodbConfig.Host -> hosts,
      MongodbConfig.Database -> dbName,
      MongodbConfig.Collection -> tableName
    ) ++ {
      saveMode match {
        case SaveModeEnum.Upsert => getUpdateFieldsOptions(schema, primaryKey)
        case _ => Map.empty[String, String]
      }
    }

  private def getUpdateFieldsOptions(schema: StructType, primaryKey: Option[String]): Map[String, String] = {

    val updateFields = primaryKey.getOrElse(
      schema.fields.filter(stField =>
        stField.metadata.contains(Output.PrimaryKeyMetadataKey)).map(_.name).mkString(",")
    )

    Map(MongodbConfig.UpdateFields -> updateFields)
  }

  private def getConnectionConfs(key: String, firstJsonItem: String, secondJsonItem: String): String = {
    val conObj = properties.getMapFromJsoneyString(key)
    conObj.map(c => {
      val host = c.getOrElse(firstJsonItem, DefaultHost)
      val port = c.getOrElse(secondJsonItem, DefaultPort)
      s"$host:$port"
    }).mkString(",")
  }
} 
Example 82
Source File: LastValueOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.lastValue

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Associative, Operator, OperatorProcessMapAsAny}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.{_}
import org.apache.spark.sql.types.StructType

import scala.util.Try

class LastValueOperator(name: String,
                        val schema: StructType,
                        properties: Map[String, JSerializable]) extends Operator(name, schema, properties)
with OperatorProcessMapAsAny with Associative {

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.Any

  override def processReduce(values: Iterable[Option[Any]]): Option[Any] =
    Try(Option(values.flatten.last)).getOrElse(None)

  def associativity(values: Iterable[(String, Option[Any])]): Option[Any] = {
    val newValues = extractValues(values, Option(Operator.NewValuesKey))
    val lastValue = if(newValues.nonEmpty) newValues
    else extractValues(values, Option(Operator.OldValuesKey))

    Try(Option(transformValueByTypeOp(returnType, lastValue.last))).getOrElse(None)
  }
} 
Example 83
Source File: StddevOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.stddev

import java.io.{Serializable => JSerializable}

import breeze.stats._
import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Operator, OperatorProcessMapAsNumber}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.{_}
import org.apache.spark.sql.types.StructType

class StddevOperator(name: String,
                     val schema: StructType,
                     properties: Map[String, JSerializable]) extends Operator(name, schema, properties)
with OperatorProcessMapAsNumber {

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.Double

  override def processReduce(values: Iterable[Option[Any]]): Option[Double] = {
    val valuesFiltered = getDistinctValues(values.flatten)
    valuesFiltered.size match {
      case (nz) if (nz != 0) =>
        Some(transformValueByTypeOp(returnType, stddev(valuesFiltered.map(value =>
          TypeOp.transformValueByTypeOp(TypeOp.Double, value).asInstanceOf[Double]))))
      case _ => Some(Operator.Zero.toDouble)
    }
  }
} 
Example 84
Source File: CountOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.count

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Associative, Operator}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import com.stratio.sparta.sdk.{_}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

import scala.util.Try

class CountOperator(name: String, val schema: StructType, properties: Map[String, JSerializable])
  extends Operator(name, schema, properties) with Associative {

  val distinctFields = parseDistinctFields

  override val defaultTypeOperation = TypeOp.Long

  override def processMap(inputFieldsValues: Row): Option[Any] = {
    applyFilters(inputFieldsValues).flatMap(filteredFields => distinctFields match {
      case None => Option(CountOperator.One.toLong)
      case Some(fields) => Option(fields.map(field => filteredFields.getOrElse(field, CountOperator.NullValue))
        .mkString(Operator.UnderscoreSeparator).toString)
    })
  }

  override def processReduce(values: Iterable[Option[Any]]): Option[Long] = {
    Try {
      val longList = distinctFields match {
        case None => values.flatten.map(value => value.asInstanceOf[Number].longValue())
        case Some(fields) => values.flatten.toList.distinct.map(value => CountOperator.One.toLong)
      }
      Option(longList.sum)
    }.getOrElse(Option(Operator.Zero.toLong))
  }

  def associativity(values: Iterable[(String, Option[Any])]): Option[Long] = {
    val newValues = extractValues(values, None)
      .map(value => TypeOp.transformValueByTypeOp(TypeOp.Long, value).asInstanceOf[Long]).sum

    Try(Option(transformValueByTypeOp(returnType, newValues)))
      .getOrElse(Option(Operator.Zero.toLong))
  }

  //FIXME: We should refactor this code
  private def parseDistinctFields: Option[Seq[String]] = {
    val distinct = properties.getString("distinctFields", None)
    if (distinct.isDefined && !distinct.get.isEmpty)
      Option(distinct.get.split(Operator.UnderscoreSeparator))
    else None
  }
}

object CountOperator {

  final val NullValue = "None"
  final val One = 1
} 
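A minimal usage sketch (hypothetical, not part of the Sparta project) that drives CountOperator the way the operator tests further down this page do; the values noted in the comments are assumed expectations derived from the code above.

import com.stratio.sparta.plugin.cube.operator.count.CountOperator
import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object CountOperatorSketch extends App {
  val schema = StructType(Seq(StructField("field1", IntegerType, false)))
  val count = new CountOperator("count", schema, Map())

  val mapped = count.processMap(Row(1))                    // expected Some(1L) per event
  val reduced = count.processReduce(Seq(mapped, mapped))   // expected Some(2L)
  val merged = count.associativity(Seq(
    (Operator.OldValuesKey, reduced),
    (Operator.NewValuesKey, Some(1L))))                    // expected Some(3L)
  println(s"$mapped $reduced $merged")
}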
Example 85
Source File: MedianOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.median

import java.io.{Serializable => JSerializable}

import breeze.linalg._
import breeze.stats._
import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Operator, OperatorProcessMapAsNumber}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.{_}
import org.apache.spark.sql.types.StructType

class MedianOperator(name: String,
                     val schema: StructType,
                     properties: Map[String, JSerializable]) extends Operator(name, schema, properties)
with OperatorProcessMapAsNumber {

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.Double

  override def processReduce(values: Iterable[Option[Any]]): Option[Double] = {
    val valuesFiltered = getDistinctValues(values.flatten)
    valuesFiltered.size match {
      case (nz) if (nz != 0) => Some(transformValueByTypeOp(returnType,
        median(DenseVector(valuesFiltered.map(_.asInstanceOf[Number].doubleValue()).toArray))))
      case _ => Some(Operator.Zero.toDouble)
    }
  }
} 
Example 86
Source File: ModeOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mode

import java.io.{Serializable => JSerializable}

import org.apache.spark.sql.types.StructType

import scala.util.Try
import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Operator, OperatorProcessMapAsAny}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.{_}

class ModeOperator(name: String,
                   val schema: StructType,
                   properties: Map[String, JSerializable]) extends Operator(name, schema, properties)
with OperatorProcessMapAsAny {

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.ArrayString

  override def processReduce(values: Iterable[Option[Any]]): Option[Any] = {
    val tupla = values.groupBy(x => x).mapValues(_.size)
    if (tupla.nonEmpty) {
      val max = tupla.values.max
      Try(Some(transformValueByTypeOp(returnType, tupla.filter(_._2 == max).flatMap(tuple => tuple._1)))).get
    } else Some(List())
  }
} 
Example 87
Source File: RangeOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.range

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Operator, OperatorProcessMapAsNumber}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.{_}
import org.apache.spark.sql.types.StructType

class RangeOperator(name: String,
                    val schema: StructType,
                    properties: Map[String, JSerializable]) extends Operator(name, schema, properties)
with OperatorProcessMapAsNumber {

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.Double

  override def processReduce(values: Iterable[Option[Any]]): Option[Double] = {
    val valuesFiltered = getDistinctValues(values.flatten)
    valuesFiltered.size match {
      case (nz) if nz != 0 =>
        val valuesConverted = valuesFiltered.map(value =>
          TypeOp.transformValueByTypeOp(TypeOp.Double, value).asInstanceOf[Double])
        Some(transformValueByTypeOp(returnType, valuesConverted.max - valuesConverted.min))
      case _ => Some(Operator.Zero.toDouble)
    }
  }
} 
Example 88
Source File: AccumulatorOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.accumulator

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Associative, Operator, OperatorProcessMapAsAny}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.{_}
import org.apache.spark.sql.types.StructType

import scala.util.Try

class AccumulatorOperator(name: String,
                          val schema: StructType,
                          properties: Map[String, JSerializable]) extends Operator(name, schema, properties)
with OperatorProcessMapAsAny with Associative {

  final val Separator = " "

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.ArrayString

  override def processReduce(values: Iterable[Option[Any]]): Option[Any] =
    Try(Option(values.flatten.flatMap(value => {
      value match {
        case value if value.isInstanceOf[Seq[Any]] => value.asInstanceOf[Seq[Any]].map(_.toString)
        case _ => Seq(TypeOp.transformValueByTypeOp(TypeOp.String, value).asInstanceOf[String])
      }
    }))).getOrElse(None)

  def associativity(values: Iterable[(String, Option[Any])]): Option[Any] = {
    val newValues = getDistinctValues(extractValues(values, None).asInstanceOf[Seq[Seq[String]]].flatten)

    Try(Option(transformValueByTypeOp(returnType, newValues))).getOrElse(Option(Seq()))
  }
} 
Example 89
Source File: FirstValueOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.firstValue

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Associative, Operator, OperatorProcessMapAsAny}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.{_}
import org.apache.spark.sql.types.StructType

import scala.util.Try

class FirstValueOperator(name: String,
                         val schema: StructType,
                         properties: Map[String, JSerializable]) extends Operator(name, schema, properties)
with OperatorProcessMapAsAny with Associative {

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.Any

  override def processReduce(values: Iterable[Option[Any]]): Option[Any] =
    Try(Option(values.flatten.head)).getOrElse(None)

  def associativity(values: Iterable[(String, Option[Any])]): Option[Any] = {
    val oldValues = extractValues(values, Option(Operator.OldValuesKey))
    val firstValue = if(oldValues.nonEmpty) oldValues
    else extractValues(values, Option(Operator.NewValuesKey))

    Try(Option(transformValueByTypeOp(returnType, firstValue.head))).getOrElse(None)
  }
} 
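A short, hypothetical sketch of FirstValueOperator, patterned on the LastValueOperator test shown later in this page; the expected results in the comments are assumptions based on the implementation above.

import com.stratio.sparta.plugin.cube.operator.firstValue.FirstValueOperator
import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object FirstValueSketch extends App {
  val schema = StructType(Seq(StructField("field1", IntegerType, false)))
  val firstValue = new FirstValueOperator("firstValue", schema, Map())

  // The reduce phase keeps the first value it sees.
  println(firstValue.processReduce(Seq(Some(1), Some(2), Some(3))))   // expected Some(1)

  // When merging windows, the previously aggregated (old) value wins over new partials.
  println(firstValue.associativity(Seq(
    (Operator.OldValuesKey, Some(1)),
    (Operator.NewValuesKey, Some(2)))))                               // expected Some(1)
}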
Example 90
Source File: MeanAssociativeOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mean

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Associative, Operator, OperatorProcessMapAsNumber}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import org.apache.spark.sql.types.StructType

import scala.util.Try

class MeanAssociativeOperator(name: String, val schema: StructType, properties: Map[String, JSerializable])
  extends Operator(name, schema, properties) with OperatorProcessMapAsNumber with Associative {

  private val SumKey = "sum"
  private val MeanKey = "mean"
  private val CountKey = "count"

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.MapStringDouble

  override def processReduce(values: Iterable[Option[Any]]): Option[Seq[Double]] = {
    Try(Option(getDistinctValues(values.flatten.flatMap(value => {
      value match {
        case value if value.isInstanceOf[Seq[Double]] => value.asInstanceOf[Seq[Double]]
        case _ => List(TypeOp.transformValueByTypeOp(TypeOp.Double, value).asInstanceOf[Double])
      }
    })))).getOrElse(Some(Seq.empty[Double]))
  }

  def associativity(values: Iterable[(String, Option[Any])]): Option[Map[String, Any]] = {
    val oldValues = extractValues(values, Option(Operator.OldValuesKey))
      .map(_.asInstanceOf[Map[String, Double]]).headOption
    val newValues =  extractValues(values, Option(Operator.NewValuesKey)).flatMap(value => {
      value match {
        case value if value.isInstanceOf[Seq[Double]] => value.asInstanceOf[Seq[Double]]
        case _ => List(TypeOp.transformValueByTypeOp(TypeOp.Double, value).asInstanceOf[Double])
      }
    }).toList

    val returnValues = if(newValues.nonEmpty) {
      val oldCount = oldValues.fold(0d) { case oldV => oldV.getOrElse(CountKey, 0d)}
      val oldSum = oldValues.fold(0d) { case oldV => oldV.getOrElse(SumKey, 0d)}
      val calculatedSum = oldSum + newValues.sum
      val calculatedCount = oldCount + newValues.size.toDouble
      val calculatedMean = if (calculatedCount != 0d) calculatedSum / calculatedCount else 0d

      Map(SumKey -> calculatedSum, CountKey -> calculatedCount, MeanKey -> calculatedMean)
    } else oldValues.getOrElse(Map(SumKey -> 0d, CountKey -> 0d, MeanKey -> 0d))

    Try(Option(TypeOp.transformValueByTypeOp(returnType, returnValues))).getOrElse(Option(Map.empty[String, Double]))
  }

} 
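A hypothetical sketch of the associative mean; the "sum"/"count"/"mean" keys come from the private constants in the class above, and the expected results in the comments are assumptions.

import com.stratio.sparta.plugin.cube.operator.mean.MeanAssociativeOperator
import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object MeanAssociativeSketch extends App {
  val schema = StructType(Seq(StructField("field1", IntegerType, false)))
  val mean = new MeanAssociativeOperator("mean", schema, Map())

  // processReduce only collects the partial values as doubles.
  val partial = mean.processReduce(Seq(Some(1), Some(2), Some(3)))   // expected Some(Seq(1.0, 2.0, 3.0))

  // associativity merges an already aggregated sum/count/mean map with the new partials.
  val merged = mean.associativity(Seq(
    (Operator.OldValuesKey, Some(Map("sum" -> 6d, "count" -> 3d, "mean" -> 2d))),
    (Operator.NewValuesKey, Some(Seq(4d, 5d)))))
  println(merged)   // expected Some(Map("sum" -> 15.0, "count" -> 5.0, "mean" -> 3.0))
}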
Example 91
Source File: MeanOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mean

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Operator, OperatorProcessMapAsNumber}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.{_}
import org.apache.spark.sql.types.StructType

class MeanOperator(name: String, val schema: StructType, properties: Map[String, JSerializable])
  extends Operator(name, schema, properties) with OperatorProcessMapAsNumber {

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.Double

  override def processReduce(values: Iterable[Option[Any]]): Option[Double] = {
    val distinctValues = getDistinctValues(values.flatten)
    distinctValues.size match {
      case (nz) if nz != 0 => Some(transformValueByTypeOp(returnType,
        distinctValues.map(_.asInstanceOf[Number].doubleValue()).sum / distinctValues.size))
      case _ => Some(Operator.Zero.toDouble)
    }
  }
} 
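A tiny, hypothetical sketch of the plain (non-associative) mean; the expected values in the comments are assumptions based on the code above.

import com.stratio.sparta.plugin.cube.operator.mean.MeanOperator
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object MeanOperatorSketch extends App {
  val schema = StructType(Seq(StructField("field1", IntegerType, false)))
  val mean = new MeanOperator("mean", schema, Map())

  println(mean.processReduce(Seq(Some(1), Some(2), Some(3), Some(6))))   // expected Some(3.0)
  println(mean.processReduce(Seq()))                                     // expected Some(0.0)
}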
Example 92
Source File: EntityCountOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.entityCount

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Associative, Operator}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import org.apache.spark.sql.types.StructType

import scala.util.Try

class EntityCountOperator(name: String,
                          schema: StructType,
                          properties: Map[String, JSerializable])
  extends OperatorEntityCount(name, schema, properties) with Associative {

  final val Some_Empty = Some(Map("" -> 0L))

  override val defaultTypeOperation = TypeOp.MapStringLong

  override def processReduce(values: Iterable[Option[Any]]): Option[Seq[String]] =
    Try(Option(values.flatten.flatMap(_.asInstanceOf[Seq[String]]).toSeq))
      .getOrElse(None)

  def associativity(values: Iterable[(String, Option[Any])]): Option[Map[String, Long]] = {
    val oldValues = extractValues(values, Option(Operator.OldValuesKey))
      .flatMap(_.asInstanceOf[Map[String, Long]]).toList
    val newValues = applyCount(extractValues(values, Option(Operator.NewValuesKey))
      .flatMap(value => {
        value match {
          case value if value.isInstanceOf[Seq[String]] => value.asInstanceOf[Seq[String]]
          case _ => List(TypeOp.transformValueByTypeOp(TypeOp.String, value).asInstanceOf[String])
        }
      }).toList).toList
    val wordCounts = applyCountMerge(oldValues ++ newValues)

    Try(Option(transformValueByTypeOp(returnType, wordCounts)))
      .getOrElse(Option(Map()))
  }

  private def applyCount(values: List[String]): Map[String, Long] =
    values.groupBy((word: String) => word).mapValues(_.length.toLong)

  private def applyCountMerge(values: List[(String, Long)]): Map[String, Long] =
    values.groupBy { case (word, count) => word }.mapValues {
      listValues => listValues.map { case (key, value) => value }.sum
    }
} 
Example 93
Source File: OperatorEntityCount.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.entityCount

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

abstract class OperatorEntityCount(name: String,
                                   val schema: StructType,
                                   properties: Map[String, JSerializable])
  extends Operator(name, schema, properties) {

  val split = if (properties.contains("split")) Some(properties.getString("split")) else None

  val replaceRegex =
    if (properties.contains("replaceRegex")) Some(properties.getString("replaceRegex")) else None

  override def processMap(inputFieldsValues: Row): Option[Seq[String]] = {
    if (inputField.isDefined && schema.fieldNames.contains(inputField.get))
      applyFilters(inputFieldsValues).flatMap(filteredFields => filteredFields.get(inputField.get).map(applySplitters))
    else None
  }

  private def applySplitters(value: Any): Seq[String] = {
    val replacedValue = applyReplaceRegex(value.toString)
    if (split.isDefined) replacedValue.split(split.get) else Seq(replacedValue)
  }

  private def applyReplaceRegex(value: String): String =
    replaceRegex match {
      case Some(regex) => value.replaceAll(regex, "")
      case None => value
    }
} 
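A hypothetical sketch combining Examples 92 and 93: OperatorEntityCount tokenises the configured input column and EntityCountOperator folds the tokens into a word-count map. The "inputField" and "split" property names come from the code above; the expected results are assumptions.

import com.stratio.sparta.plugin.cube.operator.entityCount.EntityCountOperator
import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object EntityCountSketch extends App {
  val schema = StructType(Seq(StructField("text", StringType, false)))
  // "split" drives OperatorEntityCount.applySplitters; "inputField" selects the column to tokenise.
  val entityCount = new EntityCountOperator("entityCount", schema,
    Map("inputField" -> "text", "split" -> " "))

  val mapped = entityCount.processMap(Row("hello world hello"))   // expected Some(Seq("hello", "world", "hello"))
  val reduced = entityCount.processReduce(Seq(mapped))
  println(entityCount.associativity(Seq((Operator.NewValuesKey, reduced))))
  // expected Some(Map("hello" -> 2L, "world" -> 1L))
}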
Example 94
Source File: VarianceOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.variance

import java.io.{Serializable => JSerializable}

import breeze.stats._
import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Operator, OperatorProcessMapAsNumber}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.{_}
import org.apache.spark.sql.types.StructType

class VarianceOperator(name: String,
                       val schema: StructType,
                       properties: Map[String, JSerializable]) extends Operator(name, schema, properties)
with OperatorProcessMapAsNumber {

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.Double

  override def processReduce(values: Iterable[Option[Any]]): Option[Double] = {
    val valuesFiltered = getDistinctValues(values.flatten)
    valuesFiltered.size match {
      case (nz) if nz != 0 =>
        Some(transformValueByTypeOp(returnType, variance(valuesFiltered.map(value =>
          TypeOp.transformValueByTypeOp(TypeOp.Double, value).asInstanceOf[Double]))))
      case _ => Some(Operator.Zero.toDouble)
    }
  }
} 
Example 95
Source File: SumOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.sum

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Associative, Operator, OperatorProcessMapAsNumber}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import org.apache.spark.sql.types.StructType

import scala.util.Try

class SumOperator(name: String,
                  val schema: StructType,
                  properties: Map[String, JSerializable]) extends Operator(name, schema, properties)
with OperatorProcessMapAsNumber with Associative {

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.Double

  override def processReduce(values: Iterable[Option[Any]]): Option[Double] = {
    Try(Option(getDistinctValues(values.flatten.map(value =>
      TypeOp.transformValueByTypeOp(TypeOp.Double, value).asInstanceOf[Double])).sum))
      .getOrElse(Some(Operator.Zero.toDouble))
  }

  def associativity(values: Iterable[(String, Option[Any])]): Option[Double] = {
    val newValues = extractValues(values, None)

    Try(Option(transformValueByTypeOp(returnType, newValues.map(value =>
      TypeOp.transformValueByTypeOp(TypeOp.Double, value).asInstanceOf[Double]).sum)))
      .getOrElse(Some(Operator.Zero.toDouble))
  }
} 
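A hypothetical sketch of SumOperator showing the reduce phase plus window merging; the expected results in the comments are assumptions based on the code above.

import com.stratio.sparta.plugin.cube.operator.sum.SumOperator
import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object SumOperatorSketch extends App {
  val schema = StructType(Seq(StructField("field1", IntegerType, false)))
  val sum = new SumOperator("sum", schema, Map())

  val reduced = sum.processReduce(Seq(Some(1), Some(2), Some(3.5)))   // expected Some(6.5)
  println(sum.associativity(Seq(
    (Operator.OldValuesKey, Some(10d)),
    (Operator.NewValuesKey, reduced))))                               // expected Some(16.5)
}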
Example 96
Source File: FullTextOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.fullText

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.{Associative, Operator, OperatorProcessMapAsAny}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.{_}
import org.apache.spark.sql.types.StructType

import scala.util.Try

class FullTextOperator(name: String,
                       val schema: StructType,
                       properties: Map[String, JSerializable]) extends Operator(name, schema, properties)
with OperatorProcessMapAsAny with Associative {

  val inputSchema = schema

  override val defaultTypeOperation = TypeOp.String

  override def processReduce(values: Iterable[Option[Any]]): Option[String] = {
    Try(Option(values.flatten.map(_.toString).mkString(Operator.SpaceSeparator)))
      .getOrElse(Some(Operator.EmptyString))
  }

  def associativity(values: Iterable[(String, Option[Any])]): Option[String] = {
    val newValues = extractValues(values, None).map(_.toString).mkString(Operator.SpaceSeparator)

    Try(Option(transformValueByTypeOp(returnType, newValues)))
      .getOrElse(Some(Operator.EmptyString))
  }
} 
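A hypothetical sketch of FullTextOperator, assuming Operator.SpaceSeparator is a single space; the expected results in the comments are assumptions.

import com.stratio.sparta.plugin.cube.operator.fullText.FullTextOperator
import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object FullTextSketch extends App {
  val schema = StructType(Seq(StructField("field1", StringType, false)))
  val fullText = new FullTextOperator("fullText", schema, Map())

  println(fullText.processReduce(Seq(Some("hello"), Some("world"))))   // expected Some("hello world")
  println(fullText.associativity(Seq(
    (Operator.OldValuesKey, Some("hello world")),
    (Operator.NewValuesKey, Some("again")))))                          // expected Some("hello world again")
}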
Example 97
Source File: TotalEntityCountOperator.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.totalEntityCount

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.plugin.cube.operator.entityCount.OperatorEntityCount
import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk._
import com.stratio.sparta.sdk.pipeline.aggregation.operator.Associative
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import org.apache.spark.sql.types.StructType

import scala.util.Try

class TotalEntityCountOperator(name: String,
                               schema: StructType,
                               properties: Map[String, JSerializable])
  extends OperatorEntityCount(name, schema, properties) with Associative {

  final val Some_Empty = Some(0)

  override val defaultTypeOperation = TypeOp.Int

  override def processReduce(values: Iterable[Option[Any]]): Option[Int] =
    Try(Option(values.flatten.map(value => {
      value match {
        case value if value.isInstanceOf[Seq[_]] => getDistinctValues(value.asInstanceOf[Seq[_]]).size
        case _ => TypeOp.transformValueByTypeOp(TypeOp.Int, value).asInstanceOf[Int]
      }
    }).sum)).getOrElse(Some_Empty)

  def associativity(values: Iterable[(String, Option[Any])]): Option[Int] = {
    val newValues =
      extractValues(values, None).map(value => TypeOp.transformValueByTypeOp(TypeOp.Int, value).asInstanceOf[Int]).sum

    Try(Option(transformValueByTypeOp(returnType, newValues)))
      .getOrElse(Some_Empty)
  }
} 
Example 98
Source File: GeoParser.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.geo

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.pipeline.transformation.{Parser, WhenError}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

import scala.util.{Failure, Success, Try}

class GeoParser(
                 order: Integer,
                 inputField: Option[String],
                 outputFields: Seq[String],
                 schema: StructType,
                 properties: Map[String, JSerializable]
               ) extends Parser(order, inputField, outputFields, schema, properties) {

  val defaultLatitudeField = "latitude"
  val defaultLongitudeField = "longitude"
  val separator = "__"

  val latitudeField = properties.getOrElse("latitude", defaultLatitudeField).toString
  val longitudeField = properties.getOrElse("longitude", defaultLongitudeField).toString

  def parse(row: Row): Seq[Row] = {
    val newData = Try {
      val geoValue = geoField(getLatitude(row), getLongitude(row))
      outputFields.map(outputField => {
        val outputSchemaValid = outputFieldsSchema.find(field => field.name == outputField)
        outputSchemaValid match {
          case Some(outSchema) =>
            TypeOp.transformValueByTypeOp(outSchema.dataType, geoValue)
          case None =>
            returnWhenError(
              throw new IllegalStateException(s"Impossible to parse outputField: $outputField in the schema"))
        }
      })
    }

    returnData(newData, removeInputField(row))
  }

  private def getLatitude(row: Row): String = {
    val latitude = Try(row.get(schema.fieldIndex(latitudeField)))
      .getOrElse(throw new RuntimeException(s"Impossible to parse $latitudeField in the event: ${row.mkString(",")}"))

    latitude match {
      case valueCast: String => valueCast
      case valueCast: Array[Byte] => new Predef.String(valueCast)
      case _ => latitude.toString
    }
  }

  private def getLongitude(row: Row): String = {
    val longitude = Try(row.get(schema.fieldIndex(longitudeField)))
      .getOrElse(throw new RuntimeException(s"Impossible to parse $longitudeField in the event: ${row.mkString(",")}"))

    longitude match {
      case valueCast: String => valueCast
      case valueCast: Array[Byte] => new Predef.String(valueCast)
      case _ => longitude.toString
    }
  }

  private def geoField(latitude: String, longitude: String): String = latitude + separator + longitude
} 
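A hypothetical sketch of GeoParser. The schema, column names and the expected "lat__lon" value are assumptions; the exact shape of the rows returned by parse depends on the SDK's Parser.returnData, so only the call pattern is shown.

import com.stratio.sparta.plugin.transformation.geo.GeoParser
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object GeoParserSketch extends App {
  val schema = StructType(Seq(
    StructField("latitude", StringType, false),
    StructField("longitude", StringType, false),
    StructField("geo", StringType, false)
  ))
  // order = 1, no inputField, one output field "geo"; latitude/longitude column names use the defaults.
  val parser = new GeoParser(1, None, Seq("geo"), schema, Map())

  // The emitted rows should carry "40.4168__-3.7038" in the "geo" output field.
  println(parser.parse(Row("40.4168", "-3.7038", "")))
}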
Example 99
Source File: JsonParser.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.json

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.transformation.WhenError.WhenError
import com.stratio.sparta.sdk.pipeline.transformation.{Parser, WhenError}
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import com.stratio.sparta.sdk.properties.models.PropertiesQueriesModel
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

import scala.util.Try

class JsonParser(order: Integer,
                 inputField: Option[String],
                 outputFields: Seq[String],
                 schema: StructType,
                 properties: Map[String, JSerializable])
  extends Parser(order, inputField, outputFields, schema, properties) {

  val queriesModel = properties.getPropertiesQueries("queries")

  //scalastyle:off
  override def parse(row: Row): Seq[Row] = {
    val inputValue = Option(row.get(inputFieldIndex))
    val newData = Try {
      inputValue match {
        case Some(value) =>
          val valuesParsed = value match {
            case valueCast: Array[Byte] =>
              JsonParser.jsonParse(new Predef.String(valueCast), queriesModel, whenErrorDo)
            case valueCast: String => JsonParser.jsonParse(valueCast, queriesModel, whenErrorDo)
            case _ => JsonParser.jsonParse(value.toString, queriesModel, whenErrorDo)
          }

          outputFields.map { outputField =>
            val outputSchemaValid = outputFieldsSchema.find(field => field.name == outputField)
            outputSchemaValid match {
              case Some(outSchema) =>
                valuesParsed.get(outSchema.name) match {
                  case Some(valueParsed) if valueParsed != null =>
                    parseToOutputType(outSchema, valueParsed)
                  case _ =>
                    returnWhenError(new IllegalStateException(
                      s"The values parsed don't contain the schema field: ${outSchema.name}"))
                }
              case None =>
                returnWhenError(new IllegalStateException(
                  s"Impossible to parse outputField: $outputField in the schema"))
            }
          }
        case None =>
          returnWhenError(new IllegalStateException(s"The input value is null or empty"))
      }
    }

    returnData(newData, removeInputField(row))
  }

  //scalastyle:on
}

object JsonParser {

  def jsonParse(jsonData: String,
                queriesModel: PropertiesQueriesModel,
                whenError: WhenError = WhenError.Null): Map[String, Any] = {
    val jsonPathExtractor = new JsonPathExtractor(jsonData, whenError == WhenError.Null)

    queriesModel.queries.map(queryModel => (queryModel.field, jsonPathExtractor.query(queryModel.query))).toMap
  }
} 
Example 100
Source File: XPathParser.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.xpath

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.transformation.Parser
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import com.stratio.sparta.sdk.properties.models.PropertiesQueriesModel
import kantan.xpath._
import kantan.xpath.ops._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

import scala.util.Try

class XPathParser(order: Integer,
                  inputField: Option[String],
                  outputFields: Seq[String],
                  schema: StructType,
                  properties: Map[String, JSerializable])
  extends Parser(order, inputField, outputFields, schema, properties) {

  val queriesModel = properties.getPropertiesQueries("queries")

  //scalastyle:off
  override def parse(row: Row): Seq[Row] = {
    val inputValue = Option(row.get(inputFieldIndex))
    val newData = Try {
      inputValue match {
        case Some(value) =>
          val valuesParsed = value match {
            case valueCast: Array[Byte] => XPathParser.xPathParse(new Predef.String(valueCast), queriesModel)
            case valueCast: String => XPathParser.xPathParse(valueCast, queriesModel)
            case _ => XPathParser.xPathParse(value.toString, queriesModel)
          }

          outputFields.map { outputField =>
            val outputSchemaValid = outputFieldsSchema.find(field => field.name == outputField)
            outputSchemaValid match {
              case Some(outSchema) =>
                valuesParsed.get(outSchema.name) match {
                  case Some(valueParsed) =>
                    parseToOutputType(outSchema, valueParsed)
                  case None =>
                    returnWhenError(new IllegalStateException(
                      s"The values parsed not have the schema field: ${outSchema.name}"))
                }
              case None =>
                returnWhenError(new IllegalStateException(
                  s"Impossible to parse outputField: $outputField in the schema"))
            }
          }
        case None =>
          returnWhenError(new IllegalStateException(s"The input value is null or empty"))
      }
    }

    returnData(newData, removeInputField(row))
  }

  //scalastyle:on
}

object XPathParser {

  def xPathParse(xmlData: String, queriesModel: PropertiesQueriesModel): Map[String, Any] =
    queriesModel.queries.map(queryModel => (queryModel.field, parse[String](xmlData, queryModel.query))).toMap

  
  private def applyQuery[T: Compiler](source: String, query: String): XPathResult[T] =
    source.evalXPath[T](query)
} 
Example 101
Source File: MorphlinesParser.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.morphline

import java.io.{ByteArrayInputStream, Serializable => JSerializable}
import java.util.concurrent.ConcurrentHashMap

import com.stratio.sparta.sdk.pipeline.transformation.Parser
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import com.typesafe.config.ConfigFactory
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructField, StructType}
import org.kitesdk.morphline.api.Record

import scala.collection.JavaConverters._
import scala.util.Try

class MorphlinesParser(order: Integer,
                       inputField: Option[String],
                       outputFields: Seq[String],
                       schema: StructType,
                       properties: Map[String, JSerializable])
  extends Parser(order, inputField, outputFields, schema, properties) {

  assert(inputField.isDefined, "It's necessary to define one inputField in the Morphline Transformation")

  private val config: String = properties.getString("morphline")

  override def parse(row: Row): Seq[Row] = {
    val inputValue = Option(row.get(inputFieldIndex))
    val newData = Try {
      inputValue match {
        case Some(s: String) =>
          if (s.isEmpty) returnWhenError(new IllegalStateException(s"Impossible to parse because value is empty"))
          else parseWithMorphline(new ByteArrayInputStream(s.getBytes("UTF-8")))
        case Some(b: Array[Byte]) =>
          if (b.length == 0)  returnWhenError(new IllegalStateException(s"Impossible to parse because value is empty"))
          else parseWithMorphline(new ByteArrayInputStream(b))
        case _ =>
          returnWhenError(new IllegalStateException(s"Impossible to parse because value is empty"))
      }
    }
    returnData(newData, removeInputFieldMorphline(row))
  }

  private def removeIndex(row: Row, inputFieldIndex: Int): Row =
    if (row.size < inputFieldIndex) row
    else Row.fromSeq(row.toSeq.take(inputFieldIndex) ++ row.toSeq.drop(inputFieldIndex + 1))

  private def removeInputFieldMorphline(row: Row): Row =
    if (inputFieldRemoved && inputField.isDefined) removeIndex(row, inputFieldIndex)
    else row

  private def parseWithMorphline(value: ByteArrayInputStream): Row = {
    val record = new Record()
    record.put(inputField.get, value)
    MorphlinesParser(order, config, outputFieldsSchema).process(record)
  }
}

object MorphlinesParser {

  private val instances = new ConcurrentHashMap[String, KiteMorphlineImpl].asScala

  def apply(order: Integer, config: String, outputFieldsSchema: Array[StructField]): KiteMorphlineImpl = {
    instances.get(config) match {
      case Some(kiteMorphlineImpl) =>
        kiteMorphlineImpl
      case None =>
        val kiteMorphlineImpl = KiteMorphlineImpl(ConfigFactory.parseString(config), outputFieldsSchema)
        instances.putIfAbsent(config, kiteMorphlineImpl)
        kiteMorphlineImpl
    }
  }
} 
Example 102
Source File: FilterParser.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.filter

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.filter.Filter
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.pipeline.schema.TypeOp._
import com.stratio.sparta.sdk.pipeline.transformation.Parser
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

class FilterParser(order: Integer,
                   inputField: Option[String],
                   outputFields: Seq[String],
                   val schema: StructType,
                   properties: Map[String, JSerializable])
  extends Parser(order, inputField, outputFields, schema, properties) with Filter {

  def filterInput: Option[String] = properties.getString("filters", None)

  def defaultCastingFilterType: TypeOp = TypeOp.Any

  override def parse(row: Row): Seq[Row] =
    applyFilters(row) match {
      case Some(valuesFiltered) =>
        Seq(Row.fromSeq(removeInputField(row)))
      case None => Seq.empty[Row]
    }
} 
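A hypothetical sketch of FilterParser using the same JSON filter syntax that appears in the operator tests further down this page; the expected keep/drop behaviour is an assumption based on the parse method above.

import com.stratio.sparta.plugin.transformation.filter.FilterParser
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object FilterParserSketch extends App {
  val schema = StructType(Seq(StructField("field1", IntegerType, false)))
  val parser = new FilterParser(1, None, Seq.empty, schema,
    Map("filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))

  println(parser.parse(Row(1)))   // expected: one row, the filter matches
  println(parser.parse(Row(5)))   // expected: Seq.empty, the row is filtered out
}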
Example 103
Source File: CsvParser.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.csv

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.transformation.Parser
import com.stratio.sparta.sdk.properties.ValidatingPropertyMap._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

import scala.util.Try

class CsvParser(order: Integer,
                 inputField: Option[String],
                 outputFields: Seq[String],
                 schema: StructType,
                 properties: Map[String, JSerializable])
  extends Parser(order, inputField, outputFields, schema, properties) {

  val fieldsModel = properties.getPropertiesFields("fields")
  val fieldsSeparator = Try(properties.getString("delimiter")).getOrElse(",")

  //scalastyle:off
  override def parse(row: Row): Seq[Row] = {
    val inputValue = Option(row.get(inputFieldIndex))
    val newData = Try {
      inputValue match {
        case Some(value) =>
          val valuesSplitted = {
            value match {
              case valueCast: Array[Byte] => new Predef.String(valueCast)
              case valueCast: String => valueCast
              case _ => value.toString
            }
          }.split(fieldsSeparator)

          if(valuesSplitted.length == fieldsModel.fields.length){
            val valuesParsed = fieldsModel.fields.map(_.name).zip(valuesSplitted).toMap
            outputFields.map { outputField =>
              val outputSchemaValid = outputFieldsSchema.find(field => field.name == outputField)
              outputSchemaValid match {
                case Some(outSchema) =>
                  valuesParsed.get(outSchema.name) match {
                    case Some(valueParsed) =>
                      parseToOutputType(outSchema, valueParsed)
                    case None =>
                      returnWhenError(new IllegalStateException(
                        s"The values parsed don't contain the schema field: ${outSchema.name}"))
                  }
                case None =>
                  returnWhenError(new IllegalStateException(
                    s"Impossible to parse outputField: $outputField in the schema"))
              }
            }
          } else returnWhenError(new IllegalStateException(
            s"The number of values split does not match the number of configured fields"))
        case None =>
          returnWhenError(new IllegalStateException(s"The input value is null or empty"))
      }
    }

    returnData(newData, removeInputField(row))
  }

  //scalastyle:on
} 
Example 104
Source File: CassandraOutputTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.output.cassandra

import java.io.{Serializable => JSerializable}

import com.datastax.spark.connector.cql.CassandraConnector
import com.stratio.sparta.sdk._
import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class CassandraOutputTest extends FlatSpec with Matchers with MockitoSugar with AnswerSugar {

  val s = "sum"
  val properties = Map(("connectionHost", "127.0.0.1"), ("connectionPort", "9042"))

  "getSparkConfiguration" should "return a Seq with the configuration" in {
    val configuration = Map(("connectionHost", "127.0.0.1"), ("connectionPort", "9042"))
    val cass = CassandraOutput.getSparkConfiguration(configuration)

    cass should be(List(("spark.cassandra.connection.host", "127.0.0.1"), ("spark.cassandra.connection.port", "9042")))
  }

  "getSparkConfiguration" should "return all cassandra-spark config" in {
    val config: Map[String, JSerializable] = Map(
      ("sparkProperties" -> JsoneyString(
        "[{\"sparkPropertyKey\":\"spark.cassandra.input.fetch.size_in_rows\",\"sparkPropertyValue\":\"2000\"}," +
          "{\"sparkPropertyKey\":\"spark.cassandra.input.split.size_in_mb\",\"sparkPropertyValue\":\"64\"}]")),
      ("anotherProperty" -> "true")
    )

    val sparkConfig = CassandraOutput.getSparkConfiguration(config)

    sparkConfig.exists(_ == ("spark.cassandra.input.fetch.size_in_rows" -> "2000")) should be(true)
    sparkConfig.exists(_ == ("spark.cassandra.input.split.size_in_mb" -> "64")) should be(true)
    sparkConfig.exists(_ == ("anotherProperty" -> "true")) should be(false)
  }

  "getSparkConfiguration" should "not return cassandra-spark config" in {
    val config: Map[String, JSerializable] = Map(
      ("hadoopProperties" -> JsoneyString(
        "[{\"sparkPropertyKey\":\"spark.cassandra.input.fetch.size_in_rows\",\"sparkPropertyValue\":\"2000\"}," +
          "{\"sparkPropertyKey\":\"spark.cassandra.input.split.size_in_mb\",\"sparkPropertyValue\":\"64\"}]")),
      ("anotherProperty" -> "true")
    )

    val sparkConfig = CassandraOutput.getSparkConfiguration(config)

    sparkConfig.exists(_ == ("spark.cassandra.input.fetch.size_in_rows" -> "2000")) should be(false)
    sparkConfig.exists(_ == ("spark.cassandra.input.split.size_in_mb" -> "64")) should be(false)
    sparkConfig.exists(_ == ("anotherProperty" -> "true")) should be(false)
  }
} 
Example 105
Source File: LastValueOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.lastValue

import java.util.Date

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class LastValueOperatorTest extends WordSpec with Matchers {

  "LastValue operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new LastValueOperator("lastValue", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new LastValueOperator("lastValue", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new LastValueOperator("lastValue", initSchema, Map())
      inputFields.processReduce(Seq()) should be(None)

      val inputFields2 = new LastValueOperator("lastValue", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(2))

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("b"))
    }

    "associative process must be " in {
      val inputFields = new LastValueOperator("lastValue", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(1L))

      val inputFields2 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)))
      inputFields2.associativity(resultInput2) should be(Some(1))

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(1)),
        (Operator.NewValuesKey, Some(2)))
      inputFields3.associativity(resultInput3) should be(Some(2))

      val inputFields4 = new LastValueOperator("lastValue", initSchema, Map())
      val resultInput4 = Seq()
      inputFields4.associativity(resultInput4) should be(None)

      val inputFields5 = new LastValueOperator("lastValue", initSchema, Map())
      val date = new Date()
      val resultInput5 = Seq((Operator.NewValuesKey, Some(date)))
      inputFields5.associativity(resultInput5) should be(Some(date))
    }
  }
} 
Example 106
Source File: StddevOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.stddev

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class StddevOperatorTest extends WordSpec with Matchers {

  "Std dev operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new StddevOperator("stdev", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new StddevOperator("stdev", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new StddevOperator("stdev", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map())
      inputFields2.processReduce(
        Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(2.8284271247461903))

      val inputFields3 = new StddevOperator("stdev", initSchema, Map())
      inputFields3.processReduce(
        Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(2.850438562747845))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(
        Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
    }

    "processReduce distinct must be " in {
      val inputFields = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be
      (Some(2.8284271247461903))

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(
        Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(2.850438562747845))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(
        Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
    }
  }
} 
Example 107
Source File: MedianOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.median

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MedianOperatorTest extends WordSpec with Matchers {

  "Median operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MedianOperator("median", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MedianOperator("median", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MedianOperator("median", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MedianOperator("median", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(3d))

      val inputFields3 = new MedianOperator("median", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3))

      val inputFields4 = new MedianOperator("median", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("3.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(2.5))

      val inputFields3 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3))

      val inputFields4 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("2.5"))
    }
  }
} 
Example 108
Source File: ModeOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mode

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ModeOperatorTest extends WordSpec with Matchers {

  "Mode operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new ModeOperator("mode", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new ModeOperator("mode", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new ModeOperator("mode", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new ModeOperator("mode", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new ModeOperator("mode", initSchema, Map())
      inputFields2.processReduce(Seq(Some("hey"), Some("hey"), Some("hi"))) should be(Some(List("hey")))

      val inputFields3 = new ModeOperator("mode", initSchema, Map())
      inputFields3.processReduce(Seq(Some("1"), Some("1"), Some("4"))) should be(Some(List("1")))

      val inputFields4 = new ModeOperator("mode", initSchema, Map())
      inputFields4.processReduce(Seq(
        Some("1"), Some("1"), Some("4"), Some("4"), Some("4"), Some("4"))) should be(Some(List("4")))

      val inputFields5 = new ModeOperator("mode", initSchema, Map())
      inputFields5.processReduce(Seq(
        Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"))) should be(Some(List("1", "2", "4")))

      val inputFields6 = new ModeOperator("mode", initSchema, Map())
      inputFields6.processReduce(Seq(
        Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"), Some("5"))
      ) should be(Some(List("1", "2", "4")))
    }
  }
} 
Example 109
Source File: RangeOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.range

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class RangeOperatorTest extends WordSpec with Matchers {

  "Range operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new RangeOperator("range", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new RangeOperator("range", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }
  }
} 
Example 110
Source File: AccumulatorOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.accumulator

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class AccumulatorOperatorTest extends WordSpec with Matchers {

  "Accumulator operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new AccumulatorOperator("accumulator", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new AccumulatorOperator("accumulator", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":2}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":2}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(Seq()))

      val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(Seq("1", "1")))

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(Seq("a", "b")))
    }

    "associative process must be " in {
      val inputFields = new AccumulatorOperator("accumulator", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(Seq(1L))),
        (Operator.NewValuesKey, Some(Seq(2L))),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Seq("1", "2")))

      val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> "arraydouble"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Seq(1))),
        (Operator.NewValuesKey, Some(Seq(3))))
      inputFields2.associativity(resultInput2) should be(Some(Seq(1d, 3d)))

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(Seq(1))),
        (Operator.NewValuesKey, Some(Seq(1))))
      inputFields3.associativity(resultInput3) should be(Some(Seq("1", "1")))
    }
  }
} 
Example 111
Source File: FirstValueOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.firstValue

import java.util.Date

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class FirstValueOperatorTest extends WordSpec with Matchers {

  "FirstValue operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new FirstValueOperator("firstValue", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new FirstValueOperator("firstValue", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields.processReduce(Seq()) should be(None)

      val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(1))

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("a"))
    }

    "associative process must be " in {
      val inputFields = new FirstValueOperator("firstValue", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(1L))

      val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)))
      inputFields2.associativity(resultInput2) should be(Some(1))

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(1)),
        (Operator.NewValuesKey, Some(1)),
        (Operator.NewValuesKey, None))
      inputFields3.associativity(resultInput3) should be(Some(1))

      val inputFields4 = new FirstValueOperator("firstValue", initSchema, Map())
      val resultInput4 = Seq()
      inputFields4.associativity(resultInput4) should be(None)

      val inputFields5 = new FirstValueOperator("firstValue", initSchema, Map())
      val date = new Date()
      val resultInput5 = Seq((Operator.NewValuesKey, Some(date)))
      inputFields5.associativity(resultInput5) should be(Some(date))
    }
  }
} 
Example 112
Source File: MeanAssociativeOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mean

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MeanAssociativeOperatorTest extends WordSpec with Matchers {

  "Mean operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MeanAssociativeOperator("avg", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MeanAssociativeOperator("avg", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be (Some(List(1.0, 1.0)))

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(List(1.0, 2.0, 3.0)))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(List()))
    }

    "processReduce distinct must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(List(1.0)))

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(List(1.0, 3.0)))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(List()))
    }

    "associative process must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map())
      val resultInput =
        Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Map("count" -> 1.0, "sum" -> 2.0, "mean" -> 2.0)))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map())
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))),
        (Operator.NewValuesKey, Some(Seq(1d))))
      inputFields2.associativity(resultInput2) should be(Some(Map("sum" -> 3.0, "count" -> 2.0, "mean" -> 1.5)))
    }
  }
} 
Example 113
Source File: MeanOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mean

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MeanOperatorTest extends WordSpec with Matchers {

  "Mean operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MeanOperator("avg", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MeanOperator("avg", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MeanOperator("avg", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MeanOperator("avg", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1))

      val inputFields3 = new MeanOperator("avg", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(2))

      val inputFields4 = new MeanOperator("avg", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1))

      val inputFields3 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(2))

      val inputFields4 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0"))
    }
  }
} 
Example 114
Source File: OperatorEntityCountTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.entityCount

import java.io.{Serializable => JSerializable}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class OperatorEntityCountTest extends WordSpec with Matchers {

  "EntityCount" should {
    val props = Map(
      "inputField" -> "inputField".asInstanceOf[JSerializable],
      "split" -> ",".asInstanceOf[JSerializable])
    val schema = StructType(Seq(StructField("inputField", StringType)))
    val entityCount = new OperatorEntityCountMock("op1", schema, props)
    val inputFields = Row("hello,bye")

    "Return the associated precision name" in {
      val expected = Option(Seq("hello", "bye"))
      val result = entityCount.processMap(inputFields)
      result should be(expected)
    }

    "Return empty list" in {
      val expected = None
      val result = entityCount.processMap(Row())
      result should be(expected)
    }
  }
} 
Example 115
Source File: EntityCountOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.entityCount

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class EntityCountOperatorTest extends WordSpec with Matchers {

  "Entity Count Operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new EntityCountOperator("entityCount", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new EntityCountOperator("entityCount", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields4 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields5 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> "-"))
      inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo")))

      val inputFields6 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios")))

      val inputFields7 = new EntityCountOperator("entityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]"))
      inputFields7.processMap(Row("hola", 2)) should be(None)

      val inputFields8 = new EntityCountOperator("entityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]",
          "split" -> " "))
      inputFields8.processMap(Row("hola holo", 2)) should be(Some(Seq("hola", "holo")))
    }

    "processReduce must be " in {
      val inputFields = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(Seq()))

      val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(Seq("hola", "holo")))

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(Seq("hola", "holo", "hola")))
    }

    "associative process must be " in {
      val inputFields = new EntityCountOperator("entityCount", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Map("hola" -> 1L, "holo" -> 1L)))

      val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))),
        (Operator.NewValuesKey, Some(Seq("hola"))))
      inputFields2.associativity(resultInput2) should be(Some(Map()))

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))))
      inputFields3.associativity(resultInput3) should be(Some(Map("hola" -> 1L, "holo" -> 1L)))
    }
  }
} 
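All of the Sparta operator tests above follow the same three-step pattern: build a StructType describing the incoming Row, construct the operator with a name, the schema, and a property Map, then push rows through processMap and fold the mapped values with processReduce. A condensed, hedged sketch of that pattern using MeanOperator (values and expectations taken from the MeanOperatorTest above):

// Hedged sketch of the common pattern exercised by the tests above, using MeanOperator.
import com.stratio.sparta.plugin.cube.operator.mean.MeanOperator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object OperatorPatternSketch {
  def main(args: Array[String]): Unit = {
    // 1. Schema describing the Rows the operator will receive.
    val schema = StructType(Seq(
      StructField("field1", IntegerType, nullable = false),
      StructField("field2", IntegerType, nullable = false)))

    // 2. Operator name + schema + properties (here: which field to aggregate).
    val mean = new MeanOperator("avg", schema, Map("inputField" -> "field1"))

    // 3. Map each Row to an optional value, then reduce the collected values.
    val mapped = Seq(Row(1, 2), Row(2, 4), Row(3, 6)).map(mean.processMap)
    println(mean.processReduce(mapped)) // Some(2), per the MeanOperatorTest expectations
  }
}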
Example 116
Source File: SageMakerProtobufWriter.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package com.amazonaws.services.sagemaker.sparksdk.protobuf

import java.io.ByteArrayOutputStream

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{BytesWritable, NullWritable}
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types.StructType


  def write(row: Row): Unit = {
    val labelColumnName = options.getOrElse("labelColumnName", "label")
    val featuresColumnName = options.getOrElse("featuresColumnName", "features")

    val record = ProtobufConverter.rowToProtobuf(row, featuresColumnName, Some(labelColumnName))
    record.writeTo(byteArrayOutputStream)

    recordWriter.write(NullWritable.get(), new BytesWritable(byteArrayOutputStream.toByteArray))
    byteArrayOutputStream.reset()
  }

  override def close(): Unit = {
    recordWriter.close(context)
  }
} 
Example 117
Source File: LibSVMResponseRowDeserializer.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers

import org.apache.spark.ml.linalg.{SparseVector, SQLDataTypes}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

import com.amazonaws.services.sagemaker.sparksdk.transformation.{ContentTypes, ResponseRowDeserializer}


  override val accepts: String = ContentTypes.TEXT_LIBSVM

  private def parseLibSVMRow(record: String): Row = {
    val items = record.split(' ')
    val label = items.head.toDouble
    val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
      val entry = item.split(':')
      val index = entry(0).toInt - 1
      val value = entry(1).toDouble
      (index, value)
    }.unzip
    Row(label, new SparseVector(dim, indices.toArray, values.toArray))
  }

  override val schema: StructType = StructType(
    Array(
      StructField(labelColumnName, DoubleType, nullable = false),
      StructField(featuresColumnName, SQLDataTypes.VectorType, nullable = false)))
} 
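The excerpt above omits the class declaration, but the tests in Example 121 construct the deserializer with a single dimension argument and default column names. A minimal sketch under that assumption:

// Hedged sketch: assumes LibSVMResponseRowDeserializer(dim: Int) with default
// column names "label" and "features", as used in Example 121.
import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.LibSVMResponseRowDeserializer

object LibSVMSchemaSketch {
  def main(args: Array[String]): Unit = {
    val deserializer = new LibSVMResponseRowDeserializer(3)
    // Expected: label as non-nullable DoubleType, features as non-nullable VectorType.
    println(deserializer.schema)
  }
}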
Example 118
Source File: SchemaValidators.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package com.amazonaws.services.sagemaker.sparksdk.transformation.serializers

import org.apache.spark.ml.linalg.SQLDataTypes
import org.apache.spark.sql.types.{DoubleType, StructType}

private[serializers] object SchemaValidators {
  def labeledSchemaValidator(schema: StructType,
                             labelColumnName: String,
                             featuresColumnName: String): Unit = {
    if (
      !schema.exists(f => f.name == labelColumnName && f.dataType == DoubleType) ||
      !schema.exists(f => f.name == featuresColumnName && f.dataType == SQLDataTypes.VectorType)) {
      throw new IllegalArgumentException(s"Expecting schema with DoubleType column with name " +
        s"$labelColumnName and Vector column with name $featuresColumnName. Got ${schema.toString}")
    }
  }

  def unlabeledSchemaValidator(schema: StructType, featuresColumnName: String): Unit = {
    if (!schema.exists(f => f.name == featuresColumnName &&
      f.dataType == SQLDataTypes.VectorType)) {
      throw new IllegalArgumentException(
        s"Expecting schema with Vector column with name" +
        s" $featuresColumnName. Got ${schema.toString}")
    }
  }
} 
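Because SchemaValidators is private[serializers], it can only be exercised from code compiled into the same package. A minimal sketch under that assumption:

// Hedged sketch: must live in the same package because SchemaValidators is private[serializers].
package com.amazonaws.services.sagemaker.sparksdk.transformation.serializers

import org.apache.spark.ml.linalg.SQLDataTypes
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object SchemaValidatorsSketch {
  def main(args: Array[String]): Unit = {
    val good = StructType(Array(
      StructField("label", DoubleType, nullable = false),
      StructField("features", SQLDataTypes.VectorType, nullable = false)))
    SchemaValidators.labeledSchemaValidator(good, "label", "features") // passes silently

    val bad = StructType(Array(StructField("features", DoubleType, nullable = false)))
    // Throws IllegalArgumentException: the features column is not a VectorType.
    try SchemaValidators.unlabeledSchemaValidator(bad, "features")
    catch { case e: IllegalArgumentException => println(e.getMessage) }
  }
}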
Example 119
Source File: ProtobufRequestRowSerializerTests.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package com.amazonaws.services.sagemaker.sparksdk.transformation.serializers

import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.mock.MockitoSugar

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, SQLDataTypes}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

import com.amazonaws.services.sagemaker.sparksdk.protobuf.ProtobufConverter

class ProtobufRequestRowSerializerTests extends FlatSpec with Matchers with MockitoSugar {

  val labelColumnName = "label"
  val featuresColumnName = "features"
  val schema = StructType(Array(StructField(labelColumnName, DoubleType), StructField(
    featuresColumnName, VectorType)))

  it should "serialize a dense vector" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    val rrs = new ProtobufRequestRowSerializer(Some(schema))
    val protobuf = ProtobufConverter.rowToProtobuf(row, featuresColumnName, Option.empty)
    val serialized = rrs.serializeRow(row)
    val protobufIterator = ProtobufConverter.recordIOByteArrayToProtobufs(serialized)
    val protobufFromRecordIO = protobufIterator.next

    assert(!protobufIterator.hasNext)
    assert(protobuf.equals(protobufFromRecordIO))
  }

  it should "serialize a sparse vector" in {
    val vec = new SparseVector(100, Seq[Int](0, 10).toArray, Seq[Double](-100.0, 100.1).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    val rrs = new ProtobufRequestRowSerializer(Some(schema))
    val protobuf = ProtobufConverter.rowToProtobuf(row, featuresColumnName, Option.empty)
    val serialized = rrs.serializeRow(row)
    val protobufIterator = ProtobufConverter.recordIOByteArrayToProtobufs(serialized)
    val protobufFromRecordIO = protobufIterator.next

    assert(!protobufIterator.hasNext)
    assert(protobuf.equals(protobufFromRecordIO))
  }

  it should "fail to set schema on invalid features name" in {
    val vec = new SparseVector(100, Seq[Int](0, 10).toArray, Seq[Double](-100.0, 100.1).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    intercept[IllegalArgumentException] {
      val rrs = new ProtobufRequestRowSerializer(Some(schema), featuresColumnName = "doesNotExist")
    }
  }


  it should "fail on invalid types" in {
    val schemaWithInvalidFeaturesType = StructType(Array(
      StructField("label", DoubleType, nullable = false),
      StructField("features", StringType, nullable = false)))
    intercept[RuntimeException] {
      new ProtobufRequestRowSerializer(Some(schemaWithInvalidFeaturesType))
    }
  }

  it should "validate correct schema" in {
    val validSchema = StructType(Array(
      StructField("features", SQLDataTypes.VectorType, nullable = false)))
    new ProtobufRequestRowSerializer(Some(validSchema))
  }
} 
Example 120
Source File: UnlabeledLibSVMRequestRowSerializerTests.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package com.amazonaws.services.sagemaker.sparksdk.transformation.serializers

import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.mock.MockitoSugar

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, SQLDataTypes}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class UnlabeledLibSVMRequestRowSerializerTests extends FlatSpec with Matchers with MockitoSugar {

  val schema = StructType(Array(StructField("features", SQLDataTypes.VectorType, nullable = false)))

  "UnlabeledLibSVMRequestRowSerializer" should "serialize sparse vector" in {
    val vec = new SparseVector(100, Seq[Int](0, 10).toArray, Seq[Double](-100.0, 100.1).toArray)
    val row = new GenericRowWithSchema(values = Seq(vec).toArray, schema = schema)
    val rrs = new UnlabeledLibSVMRequestRowSerializer()
    val serialized = new String(rrs.serializeRow(row))
    assert ("0.0 1:-100.0 11:100.1\n" == serialized)
  }

  it should "serialize dense vector" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq(vec).toArray, schema = schema)
    val rrs = new UnlabeledLibSVMRequestRowSerializer()
    val serialized = new String(rrs.serializeRow(row))
    assert("0.0 1:10.0 2:-100.0 3:2.0\n" == serialized)
  }

  it should "fail on invalid features column name" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    val rrs =
      new UnlabeledLibSVMRequestRowSerializer(featuresColumnName = "mangoes are not features")
    intercept[RuntimeException] {
      rrs.serializeRow(row)
    }
  }

  it should "fail on invalid features type" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row =
      new GenericRowWithSchema(values = Seq(1.0, "FEATURESSSSSZ!1!").toArray, schema = schema)
    val rrs = new UnlabeledLibSVMRequestRowSerializer()
    intercept[RuntimeException] {
      rrs.serializeRow(row)
    }
  }


  it should "validate correct schema" in {
    val validSchema = StructType(Array(
      StructField("features", SQLDataTypes.VectorType, nullable = false)))

    val rrs = new UnlabeledLibSVMRequestRowSerializer(Some(validSchema))
  }

  it should "fail to validate incorrect schema" in {
    val invalidSchema = StructType(Array(
      StructField("features", StringType, nullable = false)))

    intercept[IllegalArgumentException] {
      new UnlabeledLibSVMRequestRowSerializer(Some(invalidSchema))
    }
  }
} 
Example 121
Source File: LibSVMRequestRowSerializerTests.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package com.amazonaws.services.sagemaker.sparksdk.transformation.serializers

import org.scalatest._
import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.mock.MockitoSugar

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, SQLDataTypes}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.LibSVMResponseRowDeserializer

class LibSVMRequestRowSerializerTests extends FlatSpec with Matchers with MockitoSugar {
  val schema = new LibSVMResponseRowDeserializer(10).schema

  "LibSVMRequestRowSerializer" should "serialize sparse vector" in {

    val vec = new SparseVector(100, Seq[Int](0, 10).toArray, Seq[Double](-100.0, 100.1).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    val rrs = new LibSVMRequestRowSerializer(Some(schema))
    val serialized = new String(rrs.serializeRow(row))
    assert ("1.0 1:-100.0 11:100.1\n" == serialized)
  }

  it should "serialize dense vector" in {

    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq(1.0, vec).toArray, schema = schema)
    val rrs = new LibSVMRequestRowSerializer(Some(schema))
    val serialized = new String(rrs.serializeRow(row))
    assert("1.0 1:10.0 2:-100.0 3:2.0\n" == serialized)
  }

  it should "ignore other columns" in {
    val schemaWithExtraColumns = StructType(Array(
      StructField("name", StringType, nullable = false),
      StructField("label", DoubleType, nullable = false),
      StructField("features", SQLDataTypes.VectorType, nullable = false),
        StructField("favorite activity", StringType, nullable = false)))

    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq("Elizabeth", 1.0, vec, "Crying").toArray,
      schema = schemaWithExtraColumns)

    val rrs = new LibSVMRequestRowSerializer(Some(schemaWithExtraColumns))
    val serialized = new String(rrs.serializeRow(row))
    assert("1.0 1:10.0 2:-100.0 3:2.0\n" == serialized)
  }

  it should "fail on invalid features column name" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    intercept[RuntimeException] {
      new LibSVMRequestRowSerializer(Some(schema), featuresColumnName = "i do not exist dear sir!")
    }
  }

  it should "fail on invalid label column name" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    intercept[RuntimeException] {
      new LibSVMRequestRowSerializer(Some(schema),
        labelColumnName = "Sir! I must protest! I do not exist!")
    }
  }

  it should "fail on invalid types" in {
    val schemaWithInvalidLabelType = StructType(Array(
      StructField("label", StringType, nullable = false),
      StructField("features", SQLDataTypes.VectorType, nullable = false)))
    intercept[RuntimeException] {
      new LibSVMRequestRowSerializer(Some(schemaWithInvalidLabelType))
    }
    val schemaWithInvalidFeaturesType = StructType(Array(
      StructField("label", DoubleType, nullable = false),
      StructField("features", StringType, nullable = false)))
    intercept[RuntimeException] {
      new LibSVMRequestRowSerializer(Some(schemaWithInvalidFeaturesType))
    }
  }

  it should "validate correct schema" in {
    val validSchema = StructType(Array(
      StructField("label", DoubleType, nullable = false),
      StructField("features", SQLDataTypes.VectorType, nullable = false)))
    new LibSVMRequestRowSerializer(Some(validSchema))
  }
} 
Example 122
Source File: UnlabeledCSVRequestRowSerializerTests.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package unit.com.amazonaws.services.sagemaker.sparksdk.transformation.serializers

import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.mock.MockitoSugar

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, SQLDataTypes}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{StructField, StructType}

import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.UnlabeledCSVRequestRowSerializer

class UnlabeledCSVRequestRowSerializerTests extends FlatSpec with Matchers with MockitoSugar {
  val schema: StructType =
    StructType(Array(StructField("features", SQLDataTypes.VectorType, nullable = false)))

  it should "serialize sparse vector" in {

    val vec = new SparseVector(100, Seq[Int](0, 10).toArray, Seq[Double](-100.0, 100.1).toArray)
    val row = new GenericRowWithSchema(values = Seq(vec).toArray, schema = schema)
    val rrs = new UnlabeledCSVRequestRowSerializer(Some(schema))
    val serialized = new String(rrs.serializeRow(row))
    val sparseString = "-100.0," + "0.0," * 9 + "100.1," + "0.0," * 88 + "0.0\n"
    assert (sparseString == serialized)
  }

  it should "serialize dense vector" in {
    val vec = new DenseVector(Seq(10.0, -100.0, 2.0).toArray)
    val row = new GenericRowWithSchema(values = Seq(vec).toArray, schema = schema)
    val rrs = new UnlabeledCSVRequestRowSerializer(Some(schema))
    val serialized = new String(rrs.serializeRow(row))
    assert("10.0,-100.0,2.0\n" == serialized)
  }
} 
Example 123
Source File: VectorSlicerExample.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.feature.VectorSlicer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
// $example off$
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object VectorSlicerExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("VectorSlicerExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // $example on$
    val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0)))

    val defaultAttr = NumericAttribute.defaultAttr
    val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName)
    val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]])

    val dataRDD = sc.parallelize(data)
    val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField())))

    val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features")

    slicer.setIndices(Array(1)).setNames(Array("f3"))
    // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3"))

    val output = slicer.transform(dataset)
    println(output.select("userFeatures", "features").first())
    // $example off$
    sc.stop()
  }
}
// scalastyle:on println 
Example 124
Source File: RelationCatalog.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package edu.ucla.cs.wis.bigdatalog.spark

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

import scala.collection.mutable.HashMap

class RelationCatalog extends Serializable {
  val directory = HashMap.empty[String, RelationInfo]

  def addRelation(name : String, schema : StructType) : Unit = {
    val relationInfo = new RelationInfo().setSchema(schema)
    directory.get(name) match {
      case Some(oldRelationInfo) =>
        // update rdd if already present.  Schema should not change
        oldRelationInfo.setRDD(relationInfo.getRDD())
      case None => directory.put(name, relationInfo)
    }
  }

  def setRDD(name : String, rdd : RDD[InternalRow]) : Unit = {
    directory.get(name) match {
      case Some(oldRelationInfo) => oldRelationInfo.setRDD(rdd)
      case None => directory.put(name, new RelationInfo().setRDD(rdd))
    }
  }

  def getRelationInfo(name : String) : RelationInfo = {
    if (directory.contains(name))
      directory(name)
    else
      null
  }

  def removeRDD(name : String) : Unit = {
    directory.remove(name)
  }

  def clear() : Unit = {
    directory.clear()
  }

  override def toString(): String = {
    val output = new StringBuilder()
    directory.iterator.foreach(f => output.append(f.toString()))
    output.toString()
  }
}

class RelationInfo() extends Serializable {
  private var schema : StructType = _
  private var rdd : RDD[InternalRow] = _

  def getSchema() : StructType = schema

  def setSchema(schema : StructType) : RelationInfo = {
    this.schema = schema
    this
  }

  def getRDD() : RDD[InternalRow] = rdd

  def setRDD(rdd : RDD[InternalRow]) : RelationInfo = {
    this.rdd = rdd
    this
  }

  override def toString() : String = {
    "schema: " + this.schema + (if (rdd != null) " RDD")
  }
} 
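A short usage sketch of the catalog above: register a StructType under a name, then look the relation back up. The relation name is made up for illustration and no Spark context is needed for this schema-only path.

// Hedged sketch using only the classes defined above; "employees" is a hypothetical relation name.
import edu.ucla.cs.wis.bigdatalog.spark.RelationCatalog
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object RelationCatalogSketch {
  def main(args: Array[String]): Unit = {
    val catalog = new RelationCatalog
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("name", StringType, nullable = true)))

    catalog.addRelation("employees", schema)        // stores a RelationInfo carrying the schema
    val info = catalog.getRelationInfo("employees") // null if the name is unknown
    println(info.getSchema())                       // prints the StructType registered above
  }
}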
Example 125
Source File: HashSetRowIterator.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package edu.ucla.cs.wis.bigdatalog.spark.storage.set.hashset

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class ObjectHashSetRowIterator(set: ObjectHashSet) extends Iterator[InternalRow] {
  val rawIter = set.iterator()

  override final def hasNext(): Boolean = {
    rawIter.hasNext
  }

  override final def next(): InternalRow = {
    rawIter.next()
  }
}

class IntKeysHashSetRowIterator(set: IntKeysHashSet) extends Iterator[InternalRow] {
  val rawIter = set.iterator()
  val uRow = new UnsafeRow()
  val bufferHolder = new BufferHolder()
  val rowWriter = new UnsafeRowWriter()

  override final def hasNext(): Boolean = {
    rawIter.hasNext
  }

  override final def next(): InternalRow = {
    bufferHolder.reset()
    rowWriter.initialize(bufferHolder, 1)
    rowWriter.write(0, rawIter.next())

    uRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize())
    uRow
  }
}

class LongKeysHashSetRowIterator(set: LongKeysHashSet) extends Iterator[InternalRow] {
  val rawIter = set.iterator()
  val numFields = set.schemaInfo.arity
  val uRow = new UnsafeRow()
  val bufferHolder = new BufferHolder()
  val rowWriter = new UnsafeRowWriter()

  override final def hasNext(): Boolean = {
    rawIter.hasNext
  }

  override final def next(): InternalRow = {
    bufferHolder.reset()
    rowWriter.initialize(bufferHolder, numFields)

    val value = rawIter.nextLong()
    if (numFields == 2) {
      rowWriter.write(0, (value >> 32).toInt)
      rowWriter.write(1, value.toInt)
    } else {
      rowWriter.write(0, value)
    }
    uRow.pointTo(bufferHolder.buffer, numFields, bufferHolder.totalSize())
    uRow
  }
}

object HashSetRowIterator {
  def create(set: HashSet): Iterator[InternalRow] = {
    set match {
      //case set: UnsafeFixedWidthSet => set.iterator().asScala
      case set: IntKeysHashSet => new IntKeysHashSetRowIterator(set)
      case set: LongKeysHashSet => new LongKeysHashSetRowIterator(set)
      case set: ObjectHashSet => new ObjectHashSetRowIterator(set)
    }
  }
} 
Example 126
Source File: HashingTF.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}


  def setNumFeatures(value: Int): this.type = set(numFeatures, value)

  override def transform(dataset: DataFrame): DataFrame = {
    val outputSchema = transformSchema(dataset.schema)
    val hashingTF = new feature.HashingTF($(numFeatures))
    val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
    val metadata = outputSchema($(outputCol)).metadata
    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
  }

  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    require(inputType.isInstanceOf[ArrayType],
      s"The input column must be ArrayType, but got $inputType.")
    val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
    SchemaUtils.appendColumn(schema, attrGroup.toStructField())
  }

  override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
}

@Since("1.6.0")
object HashingTF extends DefaultParamsReadable[HashingTF] {

  @Since("1.6.0")
  override def load(path: String): HashingTF = super.load(path)
} 
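transformSchema above only checks that the input column is an ArrayType and appends an AttributeGroup column for the output. A minimal sketch of that schema-level check, assuming the Spark 1.6-era HashingTF setters shown here:

// Hedged sketch: exercises transformSchema only, so no SparkContext is required.
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

object HashingTFSchemaSketch {
  def main(args: Array[String]): Unit = {
    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("rawFeatures")
      .setNumFeatures(1 << 10)

    val inputSchema = StructType(Seq(StructField("words", ArrayType(StringType))))
    // Appends a vector column named "rawFeatures"; fails the require(...) if "words" is not an ArrayType.
    println(hashingTF.transformSchema(inputSchema))
  }
}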
Example 127
Source File: SQLTransformer.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.param.{ParamMap, Param}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util._
import org.apache.spark.sql.{SQLContext, DataFrame, Row}
import org.apache.spark.sql.types.StructType


  @Since("1.6.0")
  def getStatement: String = $(statement)

  private val tableIdentifier: String = "__THIS__"

  @Since("1.6.0")
  override def transform(dataset: DataFrame): DataFrame = {
    val tableName = Identifiable.randomUID(uid)
    dataset.registerTempTable(tableName)
    val realStatement = $(statement).replace(tableIdentifier, tableName)
    val outputDF = dataset.sqlContext.sql(realStatement)
    outputDF
  }

  @Since("1.6.0")
  override def transformSchema(schema: StructType): StructType = {
    val sc = SparkContext.getOrCreate()
    val sqlContext = SQLContext.getOrCreate(sc)
    val dummyRDD = sc.parallelize(Seq(Row.empty))
    val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
    dummyDF.registerTempTable(tableIdentifier)
    val outputSchema = sqlContext.sql($(statement)).schema
    outputSchema
  }

  @Since("1.6.0")
  override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
}

@Since("1.6.0")
object SQLTransformer extends DefaultParamsReadable[SQLTransformer] {

  @Since("1.6.0")
  override def load(path: String): SQLTransformer = super.load(path)
} 
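A usage sketch of the transformer above, assuming a live Spark 1.6-style SQLContext named sqlContext (for example in a spark-shell session); __THIS__ is replaced with a temporary table name at transform time:

// Hedged sketch: assumes an existing SQLContext named sqlContext.
import org.apache.spark.ml.feature.SQLTransformer

val df = sqlContext.createDataFrame(Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")

val sqlTrans = new SQLTransformer()
  .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")

// Runs the statement against df, registered as a temp table in place of __THIS__.
sqlTrans.transform(df).show()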
Example 128
Source File: Binarizer.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}


  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val td = $(threshold)
    val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 }
    val outputColName = $(outputCol)
    val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
    dataset.select(col("*"),
      binarizer(col($(inputCol))).as(outputColName, metadata))
  }

  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)

    val inputFields = schema.fields
    val outputColName = $(outputCol)

    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")

    val attr = BinaryAttribute.defaultAttr.withName(outputColName)
    val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
  }

  override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
}

@Since("1.6.0")
object Binarizer extends DefaultParamsReadable[Binarizer] {

  @Since("1.6.0")
  override def load(path: String): Binarizer = super.load(path)
} 
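The same pattern as HashingTF: transformSchema verifies that the input column is a DoubleType and appends a binary output column. A minimal sketch:

// Hedged sketch: schema-level check only, no data or SparkContext needed.
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object BinarizerSchemaSketch {
  def main(args: Array[String]): Unit = {
    val binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")
      .setThreshold(0.5)

    val inputSchema = StructType(Seq(StructField("feature", DoubleType)))
    // Adds "binarized_feature"; throws if "feature" is not DoubleType or the output name already exists.
    println(binarizer.transformSchema(inputSchema))
  }
}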
Example 129
Source File: LibSVMRelation.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.source.libsvm

import com.google.common.base.Objects

import org.apache.spark.Logging
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}


@Since("1.6.0")
class DefaultSource extends RelationProvider with DataSourceRegister {

  @Since("1.6.0")
  override def shortName(): String = "libsvm"

  @Since("1.6.0")
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
    : BaseRelation = {
    val path = parameters.getOrElse("path",
      throw new IllegalArgumentException("'path' must be specified"))
    val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
    val vectorType = parameters.getOrElse("vectorType", "sparse")
    new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
  }
} 
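Because the provider registers the short name "libsvm", it can be reached through the generic data source API. A hedged sketch, assuming an existing SQLContext named sqlContext and a hypothetical file path:

// Hedged sketch: "data/sample_libsvm_data.txt" is a hypothetical path.
val df = sqlContext.read
  .format("libsvm")
  .option("numFeatures", "780")
  .option("vectorType", "sparse")
  .load("data/sample_libsvm_data.txt")

df.printSchema()  // label: double, features: vector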
Example 130
Source File: OrcFileOperator.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.orc

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader}
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector

import org.apache.spark.Logging
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.hive.HiveMetastoreTypes
import org.apache.spark.sql.types.StructType

private[orc] object OrcFileOperator extends Logging {
  
  def getFileReader(basePath: String, config: Option[Configuration] = None): Option[Reader] = {
    def isWithNonEmptySchema(path: Path, reader: Reader): Boolean = {
      reader.getObjectInspector match {
        case oi: StructObjectInspector if oi.getAllStructFieldRefs.size() == 0 =>
          logInfo(
            s"ORC file $path has empty schema, it probably contains no rows. " +
              "Trying to read another ORC file to figure out the schema.")
          false
        case _ => true
      }
    }

    val conf = config.getOrElse(new Configuration)
    val fs = {
      val hdfsPath = new Path(basePath)
      hdfsPath.getFileSystem(conf)
    }

    listOrcFiles(basePath, conf).iterator.map { path =>
      path -> OrcFile.createReader(fs, path)
    }.collectFirst {
      case (path, reader) if isWithNonEmptySchema(path, reader) => reader
    }
  }

  def readSchema(path: String, conf: Option[Configuration]): StructType = {
    val reader = getFileReader(path, conf).getOrElse {
      throw new AnalysisException(
        s"Failed to discover schema from ORC files stored in $path. " +
          "Probably there are either no ORC files or only empty ORC files.")
    }
    val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector]
    val schema = readerInspector.getTypeName
    logDebug(s"Reading schema from file $path, got Hive schema string: $schema")
    HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType]
  }

  def getObjectInspector(
      path: String, conf: Option[Configuration]): Option[StructObjectInspector] = {
    getFileReader(path, conf).map(_.getObjectInspector.asInstanceOf[StructObjectInspector])
  }

  def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = {
    val origPath = new Path(pathStr)
    val fs = origPath.getFileSystem(conf)
    val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath)
      .filterNot(_.isDir)
      .map(_.getPath)
      .filterNot(_.getName.startsWith("_"))
      .filterNot(_.getName.startsWith("."))

    if (paths == null || paths.isEmpty) {
      throw new IllegalArgumentException(
        s"orcFileOperator: path $path does not have valid orc files matching the pattern")
    }

    paths
  }
} 
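OrcFileOperator is private[orc], so the sketch below would have to live under org.apache.spark.sql.hive.orc; it simply asks the helper for the StructType of whatever ORC files sit under a hypothetical directory:

// Hedged sketch: must be compiled in package org.apache.spark.sql.hive.orc; the path is hypothetical.
package org.apache.spark.sql.hive.orc

import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.types.StructType

object OrcSchemaSketch {
  def main(args: Array[String]): Unit = {
    // Picks the first non-empty ORC file under the path and converts its Hive schema to a StructType.
    val schema: StructType = OrcFileOperator.readSchema("/data/warehouse/events", Some(new Configuration()))
    println(schema.treeString)
  }
}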
Example 131
Source File: LocalRelation.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis}
import org.apache.spark.sql.types.{StructField, StructType}

object LocalRelation {
  def apply(output: Attribute*): LocalRelation = new LocalRelation(output)

  def apply(output1: StructField, output: StructField*): LocalRelation = {
    new LocalRelation(StructType(output1 +: output).toAttributes)
  }

  def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = {
    val schema = StructType.fromAttributes(output)
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
    LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow]))
  }

  def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = {
    val schema = StructType.fromAttributes(output)
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
    LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow]))
  }
}

case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
  extends LeafNode with analysis.MultiInstanceRelation {

  
  override final def newInstance(): this.type = {
    LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type]
  }

  override protected def stringArgs = Iterator(output)

  override def sameResult(plan: LogicalPlan): Boolean = plan match {
    case LocalRelation(otherOutput, otherData) =>
      otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data
    case _ => false
  }

  override lazy val statistics =
    Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length)
} 
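A small sketch of the varargs apply above, building a LocalRelation directly from StructFields with no SparkContext required:

// Hedged sketch using only the companion apply shown above.
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{IntegerType, StringType, StructField}

object LocalRelationSketch {
  def main(args: Array[String]): Unit = {
    val relation = LocalRelation(
      StructField("id", IntegerType, nullable = false),
      StructField("name", StringType))

    // The StructFields become resolved attributes; data defaults to Nil.
    relation.output.foreach(a => println(s"${a.name}: ${a.dataType}"))
  }
}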
Example 132
Source File: JDBCRelation.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import java.util.Properties

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}


  def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
    if (partitioning == null) return Array[Partition](JDBCPartition(null, 0))

    val numPartitions = partitioning.numPartitions
    val column = partitioning.column
    if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0))
    // Overflow and silliness can happen if you subtract then divide.
    // Here we get a little roundoff, but that's (hopefully) OK.
    val stride: Long = (partitioning.upperBound / numPartitions
                      - partitioning.lowerBound / numPartitions)
    var i: Int = 0
    var currentValue: Long = partitioning.lowerBound
    var ans = new ArrayBuffer[Partition]()
    while (i < numPartitions) {
      val lowerBound = if (i != 0) s"$column >= $currentValue" else null
      currentValue += stride
      val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
      val whereClause =
        if (upperBound == null) {
          lowerBound
        } else if (lowerBound == null) {
          upperBound
        } else {
          s"$lowerBound AND $upperBound"
        }
      ans += JDBCPartition(whereClause, i)
      i = i + 1
    }
    ans.toArray
  }
}

private[sql] case class JDBCRelation(
    url: String,
    table: String,
    parts: Array[Partition],
    properties: Properties = new Properties())(@transient val sqlContext: SQLContext)
  extends BaseRelation
  with PrunedFilteredScan
  with InsertableRelation {

  override val needConversion: Boolean = false

  override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    JDBCRDD.scanTable(
      sqlContext.sparkContext,
      schema,
      url,
      properties,
      table,
      requiredColumns,
      filters,
      parts).asInstanceOf[RDD[Row]]
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    data.write
      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
      .jdbc(url, table, properties)
  }
} 
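
For reference, a standalone sketch (with illustrative values, not taken from the original source) of how the stride arithmetic in columnPartition turns the bounds and partition count into WHERE clauses:

object ColumnPartitionSketch {
  // Mirrors the loop above for column "id", lowerBound = 0, upperBound = 100, numPartitions = 4.
  val column = "id"
  val (lower, upper, numPartitions) = (0L, 100L, 4)
  val stride: Long = upper / numPartitions - lower / numPartitions // 25

  val clauses: Seq[String] = (0 until numPartitions).map { i =>
    val lowerClause = if (i != 0) Some(s"$column >= ${lower + i * stride}") else None
    val upperClause = if (i != numPartitions - 1) Some(s"$column < ${lower + (i + 1) * stride}") else None
    (lowerClause ++ upperClause).mkString(" AND ")
  }
  // clauses: Vector("id < 25", "id >= 25 AND id < 50", "id >= 50 AND id < 75", "id >= 75")
}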
Example 133
Source File: Queryable.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.util.control.NonFatal

import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.StructType


  private[sql] def formatString (
      rows: Seq[Seq[String]],
      numRows: Int,
      hasMoreData : Boolean,
      truncate: Boolean = true): String = {
    val sb = new StringBuilder
    val numCols = schema.fieldNames.length

    // Initialise the width of each column to a minimum value of '3'
    val colWidths = Array.fill(numCols)(3)

    // Compute the width of each column
    for (row <- rows) {
      for ((cell, i) <- row.zipWithIndex) {
        colWidths(i) = math.max(colWidths(i), cell.length)
      }
    }

    // Create SeparateLine
    val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()

    // column names
    rows.head.zipWithIndex.map { case (cell, i) =>
      if (truncate) {
        StringUtils.leftPad(cell, colWidths(i))
      } else {
        StringUtils.rightPad(cell, colWidths(i))
      }
    }.addString(sb, "|", "|", "|\n")

    sb.append(sep)

    // data
    rows.tail.map {
      _.zipWithIndex.map { case (cell, i) =>
        if (truncate) {
          StringUtils.leftPad(cell.toString, colWidths(i))
        } else {
          StringUtils.rightPad(cell.toString, colWidths(i))
        }
      }.addString(sb, "|", "|", "|\n")
    }

    sb.append(sep)

    // For Data that has more than "numRows" records
    if (hasMoreData) {
      val rowsString = if (numRows == 1) "row" else "rows"
      sb.append(s"only showing top $numRows $rowsString\n")
    }

    sb.toString()
  }
} 
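
As a rough illustration of the layout this method produces (assuming truncate = true, so cells are left-padded and right-aligned), a header row plus two data rows for columns "id" and "name" would render roughly as follows; the final line appears only when hasMoreData is true:

+---+-----+
| id| name|
+---+-----+
|  1|alice|
|  2|  bob|
+---+-----+
only showing top 2 rows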
Example 134
Source File: SparkSqlSerializer.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import java.nio.ByteBuffer
import java.util.{HashMap => JavaHashMap}

import scala.reflect.ClassTag
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.twitter.chill.ResourcePool
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
import org.apache.spark.sql.types.{Decimal, StructField, StructType}
import org.apache.spark.util.MutablePair
import org.apache.spark.{SparkConf, SparkEnv}


//private[sql]
class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
  override def newKryo(): Kryo = {
    val kryo = super.newKryo()
    kryo.setRegistrationRequired(false)
    kryo.register(classOf[MutablePair[_, _]])
    kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
    kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow])
    kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
    kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer)
    kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer)

    kryo.register(classOf[Decimal])
    kryo.register(classOf[JavaHashMap[_, _]])

    // APS
    kryo.register(classOf[StructType])
    kryo.register(classOf[StructField])

    kryo.setReferences(false)
    kryo
  }
}

private[execution] class KryoResourcePool(size: Int)
  extends ResourcePool[SerializerInstance](size) {

  val ser: SparkSqlSerializer = {
    val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
    new SparkSqlSerializer(sparkConf)
  }

  def newInstance(): SerializerInstance = ser.newInstance()
}

//private[sql]
object SparkSqlSerializer {
  @transient lazy val resourcePool = new KryoResourcePool(30)

  private[this] def acquireRelease[O](fn: SerializerInstance => O): O = {
    val kryo = resourcePool.borrow
    try {
      fn(kryo)
    } finally {
      resourcePool.release(kryo)
    }
  }

  def serialize[T: ClassTag](o: T): Array[Byte] =
    acquireRelease { k =>
      k.serialize(o).array()
    }

  def deserialize[T: ClassTag](bytes: Array[Byte]): T =
    acquireRelease { k =>
      k.deserialize[T](ByteBuffer.wrap(bytes))
    }
}

private[sql] class JavaBigDecimalSerializer extends Serializer[java.math.BigDecimal] {
  def write(kryo: Kryo, output: Output, bd: java.math.BigDecimal) {
    // TODO: There are probably more efficient representations than strings...
    output.writeString(bd.toString)
  }

  def read(kryo: Kryo, input: Input, tpe: Class[java.math.BigDecimal]): java.math.BigDecimal = {
    new java.math.BigDecimal(input.readString())
  }
}

private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] {
  def write(kryo: Kryo, output: Output, bd: BigDecimal) {
    // TODO: There are probably more efficient representations than strings...
    output.writeString(bd.toString)
  }

  def read(kryo: Kryo, input: Input, tpe: Class[BigDecimal]): BigDecimal = {
    new java.math.BigDecimal(input.readString())
  }
} 
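
A small usage sketch of the serialize/deserialize helpers above, round-tripping a StructType through the shared Kryo pool (illustrative, not from the original project):

import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object SparkSqlSerializerSketch {
  def roundTrip(): Unit = {
    val schema = new StructType().add("id", IntegerType).add("name", StringType)
    // serialize/deserialize borrow a SerializerInstance from the KryoResourcePool defined above.
    val bytes: Array[Byte] = SparkSqlSerializer.serialize(schema)
    val restored: StructType = SparkSqlSerializer.deserialize[StructType](bytes)
    assert(restored == schema)
  }
}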
Example 135
Source File: SemiJoinSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  private lazy val left = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  private lazy val right = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join)
    }

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            EnsureRequirements(left.sqlContext).apply(
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }
  }

  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
} 
Example 136
Source File: TextSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.text

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.util.Utils


class TextSuite extends QueryTest with SharedSQLContext {

  test("reading text file") {
    verifyFrame(sqlContext.read.format("text").load(testFile))
  }

  test("SQLContext.read.text() API") {
    verifyFrame(sqlContext.read.text(testFile))
  }

  test("SPARK-12562 verify write.text() can handle column name beyond `value`") {
    val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf")

    val tempFile = Utils.createTempDir()
    tempFile.delete()
    df.write.text(tempFile.getCanonicalPath)
    verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))

    Utils.deleteRecursively(tempFile)
  }

  test("error handling for invalid schema") {
    val tempFile = Utils.createTempDir()
    tempFile.delete()

    val df = sqlContext.range(2)
    intercept[AnalysisException] {
      df.write.text(tempFile.getCanonicalPath)
    }

    intercept[AnalysisException] {
      sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath)
    }
  }

  private def testFile: String = {
    Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
  }

  
  private def verifyFrame(df: DataFrame): Unit = {
    // schema
    assert(df.schema == new StructType().add("value", StringType))

    // verify content
    val data = df.collect()
    assert(data(0) == Row("This is a test file for the text data source"))
    assert(data(1) == Row("1+1"))
    // Non-ASCII characters are not allowed in the code, so we disable scalastyle here.
    // scalastyle:off
    assert(data(2) == Row("数据砖头"))
    // scalastyle:on
    assert(data(3) == Row("\"doh\""))
    assert(data.length == 4)
  }
} 
Example 137
Source File: GroupedIteratorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema)

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 138
Source File: ListTablesSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
import org.apache.spark.sql.catalyst.TableIdentifier

class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")

  before {
    df.registerTempTable("ListTablesSuiteTable")
  }

  after {
    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
  }

  test("get all tables") {
    checkAnswer(
      sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    checkAnswer(
      sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
    assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("getting all Tables with a database name has no impact on returned table names") {
    checkAnswer(
      sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    checkAnswer(
      sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
    assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("query the returned DataFrame of tables") {
    val expectedSchema = StructType(
      StructField("tableName", StringType, false) ::
      StructField("isTemporary", BooleanType, false) :: Nil)

    Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach {
      case tableDF =>
        assert(expectedSchema === tableDF.schema)

        tableDF.registerTempTable("tables")
        checkAnswer(
          sql(
            "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
          Row(true, "ListTablesSuiteTable")
        )
        checkAnswer(
          sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
          Row("tables", true))
        sqlContext.dropTempTable("tables")
    }
  }
} 
Example 139
Source File: DDLSourceLoadSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// please note that the META-INF/services had to be modified for the test directory for this to work
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {
    intercept[RuntimeException] {
      caseInsensitiveContext.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {
    caseInsensitiveContext.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("specify full classname with duplicate formats") {
    caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("should fail to load ORC without HiveContext") {
    intercept[ClassNotFoundException] {
      caseInsensitiveContext.read.format("orc").load()
    }
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider  with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
Example 140
Source File: ScalaSparkSQLBySchema.scala    From learning-spark   with Apache License 2.0 5 votes vote down vote up
package com.javachen.spark.examples.sparksql

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.{SparkConf, SparkContext}

object ScalaSparkSQLBySchema {

  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("ScalaSparkSQL"))
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)

    // Create an RDD
    val people = sc.textFile("people.txt")

    // The schema is encoded in a string
    val schemaString = "name age"

    // Import Spark SQL data types and Row.
    import org.apache.spark.sql._

    // Generate the schema based on the string of schema
    val schema =
      StructType(
        schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true)))

    // Convert records of the RDD (people) to Rows.
    val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim))

    // Apply the schema to the RDD.
    val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema)

    // Register the DataFrames as a table.
    peopleDataFrame.registerTempTable("people")

    // SQL statements can be run by using the sql methods provided by sqlContext.
    val results = sqlContext.sql("SELECT name FROM people")

    // The results of SQL queries are DataFrames and support all the normal RDD operations.
    // The columns of a row in the result can be accessed by ordinal.
    results.map(t => "Name: " + t(0)).collect().foreach(println)
  }
} 
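
The string-driven schema above types every column as a string. A small sketch (an assumption, not part of the original example) of the equivalent fluent builder with age declared as an integer, which would also require parsing the age field when building each Row:

import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object TypedPeopleSchemaSketch {
  // Same two columns as "name age", but with age declared as an integer.
  val typedSchema: StructType = new StructType()
    .add("name", StringType, nullable = true)
    .add("age", IntegerType, nullable = true)
  // Rows would then need to be built as Row(p(0), p(1).trim.toInt) to match this schema.
}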
Example 141
Source File: DefaultSource.scala    From Spark-MongoDB   with Apache License 2.0 5 votes vote down vote up
package com.stratio.datasource.mongodb

import com.stratio.datasource.mongodb.config.MongodbConfigBuilder
import com.stratio.datasource.mongodb.config.MongodbConfig._
import org.apache.spark.sql.SaveMode._
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}


class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider{

  override def createRelation(
                               sqlContext: SQLContext,
                               parameters: Map[String, String]): BaseRelation = {

    new MongodbRelation(MongodbConfigBuilder(parseParameters(parameters)).build())(sqlContext)

  }

  override def createRelation(
                               sqlContext: SQLContext,
                               parameters: Map[String, String],
                               schema: StructType): BaseRelation = {

    new MongodbRelation(MongodbConfigBuilder(parseParameters(parameters)).build(), Some(schema))(sqlContext)

  }

  override def createRelation(
                               sqlContext: SQLContext,
                               mode: SaveMode,
                               parameters: Map[String, String],
                               data: DataFrame): BaseRelation = {

    val mongodbRelation = new MongodbRelation(
      MongodbConfigBuilder(parseParameters(parameters)).build(), Some(data.schema))(sqlContext)

    mode match{
      case Append         => mongodbRelation.insert(data, overwrite = false)
      case Overwrite      => mongodbRelation.insert(data, overwrite = true)
      case ErrorIfExists  => if(mongodbRelation.isEmptyCollection) mongodbRelation.insert(data, overwrite = false)
      else throw new UnsupportedOperationException("Writing in a non-empty collection.")
      case Ignore         => if(mongodbRelation.isEmptyCollection) mongodbRelation.insert(data, overwrite = false)
    }

    mongodbRelation
  }

} 
Example 142
Source File: LOFSuite.scala    From spark-lof   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.outlier

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.apache.spark.sql.functions._

object LOFSuite {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("LOFExample")
      .master("local[4]")
      .getOrCreate()

    val schema = new StructType(Array(
      new StructField("col1", DataTypes.DoubleType),
      new StructField("col2", DataTypes.DoubleType)))
    val df = spark.read.schema(schema).csv("data/outlier.csv")

    val assembler = new VectorAssembler()
      .setInputCols(df.columns)
      .setOutputCol("features")
    val data = assembler.transform(df).repartition(4)

    val startTime = System.currentTimeMillis()
    val result = new LOF()
      .setMinPts(5)
      .transform(data)
    val endTime = System.currentTimeMillis()
    result.count()

    // Outliers have much higher LOF value than normal data
    result.sort(desc(LOF.lof)).head(10).foreach { row =>
      println(row.get(0) + " | " + row.get(1) + " | " + row.get(2))
    }
    println("Total time = " + (endTime - startTime) / 1000.0 + "s")
  }
} 
Example 143
Source File: FixedwidthRelation.scala    From spark-fixedwidth   with Apache License 2.0 5 votes vote down vote up
package com.quartethealth.spark.fixedwidth

import com.databricks.spark.csv.CsvRelation
import com.databricks.spark.csv.readers.{BulkReader, LineReader}
import com.quartethealth.spark.fixedwidth.readers.{BulkFixedwidthReader, LineFixedwidthReader}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.StructType

class FixedwidthRelation protected[spark] (
    baseRDD: () => RDD[String],
    fixedWidths: Array[Int],
    location: Option[String],
    useHeader: Boolean,
    parseMode: String,
    comment: Character,
    ignoreLeadingWhiteSpace: Boolean,
    ignoreTrailingWhiteSpace: Boolean,
    treatEmptyValuesAsNulls: Boolean,
    userSchema: StructType,
    inferSchema: Boolean,
    codec: String = null,
    nullValue: String = "")(@transient override val sqlContext: SQLContext)
  extends CsvRelation(
    baseRDD,
    location,
    useHeader,
    delimiter = '\0',
    quote = null,
    escape = null,
    comment = comment,
    parseMode = parseMode,
    parserLib = "UNIVOCITY",
    ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
    ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
    treatEmptyValuesAsNulls = treatEmptyValuesAsNulls,
    userSchema = userSchema,
    inferCsvSchema = true,
    codec = codec)(sqlContext) {

  protected override def getLineReader(): LineReader = {
    val commentChar: Char = if (comment == null) '\0' else comment
    new LineFixedwidthReader(fixedWidths, commentMarker = commentChar,
      ignoreLeadingSpace = ignoreLeadingWhiteSpace,
      ignoreTrailingSpace = ignoreTrailingWhiteSpace)
  }

  protected override def getBulkReader(header: Seq[String], iter: Iterator[String],
      split: Int): BulkReader = {
    val commentChar: Char = if (comment == null) '\0' else comment
    new BulkFixedwidthReader(iter, split, fixedWidths,
      headers = header, commentMarker = commentChar,
      ignoreLeadingSpace = ignoreLeadingWhiteSpace,
      ignoreTrailingSpace = ignoreTrailingWhiteSpace)
  }
} 
Example 144
Source File: package.scala    From spark-fixedwidth   with Apache License 2.0 5 votes vote down vote up
package com.quartethealth.spark

import com.databricks.spark.csv.util.TextFile
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext}

package object fixedwidth {
  implicit class FixedwidthContext(sqlContext: SQLContext) extends Serializable {
    def fixedFile(
        filePath: String,
        fixedWidths: Array[Int],
        schema: StructType = null,
        useHeader: Boolean = true,
        mode: String = "PERMISSIVE",
        comment: Character = null,
        ignoreLeadingWhiteSpace: Boolean = true,
        ignoreTrailingWhiteSpace: Boolean = true,
        charset: String = TextFile.DEFAULT_CHARSET.name(),
        inferSchema: Boolean = false): DataFrame = {
      val fixedwidthRelation = new FixedwidthRelation(
        () => TextFile.withCharset(sqlContext.sparkContext, filePath, charset),
        location = Some(filePath),
        useHeader = useHeader,
        comment = comment,
        parseMode = mode,
        fixedWidths = fixedWidths,
        ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
        ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
        userSchema = schema,
        inferSchema = inferSchema,
        treatEmptyValuesAsNulls = false)(sqlContext)
      sqlContext.baseRelationToDataFrame(fixedwidthRelation)
    }
  }
} 
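
A usage sketch for the implicit FixedwidthContext above; the file path and column widths are illustrative assumptions, not taken from the project:

import com.quartethealth.spark.fixedwidth._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object FixedFileSketch {
  // Hypothetical layout: the first 4 characters are an id, the next 12 a name.
  def readAccounts(sqlContext: SQLContext): DataFrame = {
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("name", StringType)))
    sqlContext.fixedFile("data/accounts.txt", Array(4, 12), schema = schema, useHeader = false)
  }
}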
Example 145
Source File: MetastoreIndexSuite.scala    From parquet-index   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

import com.github.lightcopy.testutil.UnitTestSuite
import com.github.lightcopy.testutil.implicits._

// Test catalog to check internal methods
private[datasources] class TestIndex extends MetastoreIndex {
  private var internalIndexFilters: Seq[Filter] = Nil
  override def tablePath(): Path = ???
  override def partitionSchema: StructType = ???
  override def indexSchema: StructType = ???
  override def dataSchema: StructType = ???
  override def setIndexFilters(filters: Seq[Filter]) = {
    internalIndexFilters = filters
  }
  override def indexFilters: Seq[Filter] = internalIndexFilters
  override def listFilesWithIndexSupport(
      partitionFilters: Seq[Expression],
      dataFilters: Seq[Expression],
      indexFilters: Seq[Filter]): Seq[PartitionDirectory] = ???
  override def inputFiles: Array[String] = ???
  override def sizeInBytes: Long = ???
}

class MetastoreIndexSuite extends UnitTestSuite {
  test("provide sequence of path based on table path") {
    val catalog = new TestIndex() {
      override def tablePath(): Path = new Path("test")
    }

    catalog.rootPaths should be (Seq(new Path("test")))
  }

  test("when using listFiles directly supply empty index filter") {
    var indexSeq: Seq[Filter] = null
    var filterSeq: Seq[Expression] = null
    val catalog = new TestIndex() {
      override def listFilesWithIndexSupport(
          partitionFilters: Seq[Expression],
          dataFilters: Seq[Expression],
          indexFilters: Seq[Filter]): Seq[PartitionDirectory] = {
        indexSeq = indexFilters
        filterSeq = partitionFilters
        Seq.empty
      }
    }

    catalog.listFiles(Seq.empty, Seq.empty)
    indexSeq should be (Nil)
    filterSeq should be (Nil)
  }

  test("refresh should be no-op by default") {
    val catalog = new TestIndex()
    catalog.refresh()
  }
} 
Example 146
Source File: InfinispanRelation.scala    From infinispan-spark   with Apache License 2.0 5 votes vote down vote up
package org.infinispan.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}
import org.infinispan.client.hotrod.marshall.MarshallerUtil
import org.infinispan.spark.config.ConnectorConfiguration
import org.infinispan.spark.rdd.InfinispanRDD


class InfinispanRelation(context: SQLContext, val parameters: Map[String, String])
  extends BaseRelation with PrunedFilteredScan with Serializable {

   override def sqlContext: SQLContext = context

   lazy val props: ConnectorConfiguration = ConnectorConfiguration(parameters)

   val clazz = {
      val protoEntities = props.getProtoEntities
      val targetEntity = Option(props.getTargetEntity) match {
         case Some(p) => p
         case None =>
            if (protoEntities.nonEmpty) protoEntities.head
            else throw new IllegalArgumentException(
               "No target entity or annotated protobuf entities found; check the configuration")
      }
      targetEntity
   }

   @transient lazy val mapper = ObjectMapper.forBean(schema, clazz)

   override def schema: StructType = SchemaProvider.fromJavaBean(clazz)

   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
      val rdd: InfinispanRDD[AnyRef, AnyRef] = new InfinispanRDD(context.sparkContext, props)
      val serCtx = MarshallerUtil.getSerializationContext(rdd.remoteCache.getRemoteCacheManager)
      val message = serCtx.getMarshaller(clazz).getTypeName
      val projections = toIckle(requiredColumns)
      val predicates = toIckle(filters)

      val select = if (projections.nonEmpty) s"SELECT $projections" else ""
      val from = s"FROM $message"
      val where = if (predicates.nonEmpty) s"WHERE $predicates" else ""

      val query = s"$select $from $where"

      rdd.filterByQuery[AnyRef](query.trim).values.map(mapper(_, requiredColumns))
   }

   def toIckle(columns: Array[String]): String = columns.mkString(",")

   def toIckle(filters: Array[Filter]): String = filters.map(ToIckle).mkString(" AND ")

   private def ToIckle(f: Filter): String = {
      f match {
         case StringEndsWith(a, v) => s"$a LIKE '%$v'"
         case StringContains(a, v) => s"$a LIKE '%$v%'"
         case StringStartsWith(a, v) => s"$a LIKE '$v%'"
         case EqualTo(a, v) => s"$a = '$v'"
         case GreaterThan(a, v) => s"$a > $v"
         case GreaterThanOrEqual(a, v) => s"$a >= $v"
         case LessThan(a, v) => s"$a < $v"
         case LessThanOrEqual(a, v) => s"$a <= $v"
         case IsNull(a) => s"$a is null"
         case IsNotNull(a) => s"$a is not null"
         case In(a, vs) => s"$a IN (${vs.map(v => s"'$v'").mkString(",")})"
         case Not(filter) => s"NOT ${ToIckle(filter)}"
         case And(leftFilter, rightFilter) => s"${ToIckle(leftFilter)} AND ${ToIckle(rightFilter)}"
         case Or(leftFilter, rightFilter) => s"${ToIckle(leftFilter)} OR ${ToIckle(rightFilter)}"
      }
   }
} 
Example 147
Source File: ObjectMapper.scala    From infinispan-spark   with Apache License 2.0 5 votes vote down vote up
package org.infinispan.spark.sql

import java.beans.Introspector

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRowWithSchema}
import org.apache.spark.sql.types.StructType


object ObjectMapper {

   def forBean(schema: StructType, beanClass: Class[_]): (AnyRef, Array[String]) => Row = {
      val beanInfo = Introspector.getBeanInfo(beanClass)
      val attrs = schema.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
      val extractors = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
      val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
         (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
      }
      (from: Any, columns: Array[String]) => {
         if (columns.nonEmpty) {
            from match {
               case _: Array[_] => new GenericRowWithSchema(from.asInstanceOf[Array[Any]], schema)
               case f: Any =>
                  val rowSchema = StructType(Array(schema(columns.head)))
                  new GenericRowWithSchema(Array(f), rowSchema)
            }
         } else {
            new GenericRowWithSchema(methodsToConverts.map { case (e, convert) =>
               val invoke: AnyRef = e.invoke(from)
               convert(invoke)
            }, schema)

         }
      }
   }

} 
Example 148
Source File: CloudPartitionTest.scala    From cloud-integration   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.hadoop.fs.Path

import org.apache.spark.sql._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


abstract class CloudPartitionTest extends AbstractCloudRelationTest {

import testImplicits._

  ctest(
    "save-findClass-partitioned-part-columns-in-data",
    "Save sets of files in explicitly set up partition tree; read") {
    withTempPathDir("part-columns", None) { path =>
      for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
        val partitionDir = new Path(path, s"p1=$p1/p2=$p2")
        val df = sparkContext
          .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
          .toDF("a", "b", "p1")

         df.write
          .format(dataSourceName)
          .mode(SaveMode.ErrorIfExists)
          .save(partitionDir.toString)
        // each of these directories as its own success file; there is
        // none at the root
        resolveSuccessFile(partitionDir, true)
      }

      val dataSchemaWithPartition =
        StructType(
          dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))

      checkQueries(
        spark.read.options(Map(
          "path" -> path.toString,
          "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName)
          .load())
    }
  }
} 
Example 149
Source File: PileupTestBase.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.tests.pileup

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{IntegerType, ShortType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfter, FunSuite}

class PileupTestBase extends FunSuite
  with DataFrameSuiteBase
  with BeforeAndAfter
  with SharedSparkContext{

  val sampleId = "NA12878.multichrom.md"
  val samResPath: String = getClass.getResource("/multichrom/mdbam/samtools.pileup").getPath
  val referencePath: String = getClass.getResource("/reference/Homo_sapiens_assembly18_chr1_chrM.small.fasta").getPath
  val bamPath: String = getClass.getResource(s"/multichrom/mdbam/${sampleId}.bam").getPath
  val cramPath : String = getClass.getResource(s"/multichrom/mdcram/${sampleId}.cram").getPath
  val tableName = "reads_bam"
  val tableNameCRAM = "reads_cram"

  val schema: StructType = StructType(
    List(
      StructField("contig", StringType, nullable = true),
      StructField("position", IntegerType, nullable = true),
      StructField("reference", StringType, nullable = true),
      StructField("coverage", ShortType, nullable = true),
      StructField("pileup", StringType, nullable = true),
      StructField("quality", StringType, nullable = true)
    )
  )
  before {
    System.setProperty("spark.kryo.registrator", "org.biodatageeks.sequila.pileup.serializers.CustomKryoRegistrator")
    spark
      .conf.set("spark.sql.shuffle.partitions",1) //FIXME: In order to get orderBy in Samtools tests working - related to exchange partitions stage
    spark.sql(s"DROP TABLE IF EXISTS $tableName")
    spark.sql(
      s"""
         |CREATE TABLE $tableName
         |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")
         |
      """.stripMargin)

    spark.sql(s"DROP TABLE IF EXISTS $tableNameCRAM")
    spark.sql(
      s"""
         |CREATE TABLE $tableNameCRAM
         |USING org.biodatageeks.sequila.datasources.BAM.CRAMDataSource
         |OPTIONS(path "$cramPath", refPath "$referencePath" )
         |
      """.stripMargin)

    val mapToString = (map: Map[Byte, Short]) => {
      if (map == null)
        "null"
      else
        map.map({
          case (k, v) => k.toChar -> v}).mkString.replace(" -> ", ":")
    }

    val byteToString = ((byte: Byte) => byte.toString)

    spark.udf.register("mapToString", mapToString)
    spark.udf.register("byteToString", byteToString)
  }

} 
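
The schema declared above matches the columns of the samtools pileup dump referenced by samResPath; a hedged sketch of applying it, assuming the file is tab-separated (the samtools pileup convention):

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.StructType

object PileupReadSketch {
  // Assumption: `path` points at a tab-separated samtools pileup file with the six columns above.
  def readSamtoolsPileup(spark: SparkSession, schema: StructType, path: String): DataFrame =
    spark.read.schema(schema).option("delimiter", "\t").csv(path)
}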
Example 150
Source File: A_1_BasicOperation.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.structured_streaming

import java.sql.Timestamp

import org.apache.spark.sql.types.{BooleanType, StringType, StructType, TimestampType}
import org.apache.spark.sql.{Dataset, SparkSession}

object A_1_BasicOperation {

  // For date/time fields use java.sql.Timestamp in the case class; Catalyst maps it to TimestampType
  case class DeviceData(device: String, deviceType: String, signal: Double, time: Timestamp)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName(A_1_BasicOperation.getClass.getName)
      .master("local")
      .getOrCreate()
    val timeStructType = new StructType().add("device", StringType)
      .add("deviceType", StringType)
      .add("signal", BooleanType)
      .add("time", TimestampType)

    val dataFrame = spark.read.json("src/main/resources/sparkresource/device.json")
    import spark.implicits._
    val ds: Dataset[DeviceData] = dataFrame.as[DeviceData]

    // Untyped, SQL-like query
    dataFrame.select("device").where("signal>10").show()
    // Typed query using the Dataset API
    ds.filter(_.signal > 10).map(_.device).show()

    // Untyped groupBy with a count per device type
    dataFrame.groupBy("deviceType").count().show()


    import org.apache.spark.sql.expressions.scalalang.typed
    // Typed aggregation: average signal value per device type
    ds.groupByKey(_.deviceType).agg(typed.avg(_.signal)).show()

    // Alternatively, register a temporary view and query it with SQL
    dataFrame.createOrReplaceTempView("device")
    spark.sql("select * from device").show()

    // isStreaming indicates whether the DataFrame contains streaming data
    println(dataFrame.isStreaming)
  }
} 
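
Note that timeStructType above is defined but never applied, and it declares signal as BooleanType while the DeviceData case class uses Double. A minimal sketch (an assumption about the intended use, with signal adjusted to DoubleType) of enforcing the schema on the read:

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType, TimestampType}

object DeviceSchemaSketch {
  def readDevices(spark: SparkSession): DataFrame = {
    // Mirrors timeStructType, but with signal typed to match DeviceData.signal: Double.
    val deviceSchema = new StructType()
      .add("device", StringType)
      .add("deviceType", StringType)
      .add("signal", DoubleType)
      .add("time", TimestampType)
    spark.read.schema(deviceSchema).json("src/main/resources/sparkresource/device.json")
  }
}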
Example 151
Source File: A_1_DataFrameTest.scala    From wow-spark   with MIT License 5 votes vote down vote up
package com.sev7e0.wow.sql

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}


    // Define a schema
    val schemaString = "name age"
    val fields = schemaString.split(" ")
      .map(filedName => StructField(filedName, StringType, nullable = true))
    val structType = StructType(fields)

    val personRDD = sparkSession.sparkContext.textFile("src/main/resources/sparkresource/people.txt")
      .map(_.split(","))
      // Convert each RDD record to a Row
      .map(attr => Row(attr(0), attr(1).trim))
    // Apply the schema to the RDD and create a DataFrame
    sparkSession.createDataFrame(personRDD,structType).createOrReplaceTempView("people1")
    val dataFrameBySchema = sparkSession.sql("select name,age from people1 where age > 19 ")
    dataFrameBySchema.show()
  }

} 
Example 152
Source File: FileSystemDataLink.scala    From lighthouse   with Apache License 2.0 5 votes vote down vote up
package be.dataminded.lighthouse.datalake

import be.dataminded.lighthouse.spark.{Orc, SparkFileFormat}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, SaveMode}


class FileSystemDataLink(
    val path: LazyConfig[String],
    format: SparkFileFormat = Orc,
    saveMode: SaveMode = SaveMode.Overwrite,
    partitionedBy: List[String] = List.empty,
    options: Map[String, String] = Map.empty,
    schema: Option[StructType] = None
) extends PathBasedDataLink {

  override def doRead(path: String): DataFrame = {
    schema match {
      case Some(s) => spark.read.format(format.toString).options(options).schema(s).load(path)
      case None    => spark.read.format(format.toString).options(options).load(path)
    }
  }

  override def doWrite[T](dataset: Dataset[T], path: String): Unit = {
    dataset.write
      .format(format.toString)
      .partitionBy(partitionedBy: _*)
      .options(options)
      .mode(saveMode)
      .save(path)
  }
} 
Example 153
Source File: SparkAvroDecoder.scala    From cloudflow   with Apache License 2.0 5 votes vote down vote up
package cloudflow.spark.avro

import org.apache.log4j.Logger

import java.io.ByteArrayOutputStream

import scala.reflect.runtime.universe._

import org.apache.avro.generic.{ GenericDatumReader, GenericDatumWriter, GenericRecord }
import org.apache.avro.io.{ DecoderFactory, EncoderFactory }
import org.apache.spark.sql.{ Dataset, Encoder, Row }
import org.apache.spark.sql.catalyst.encoders.{ encoderFor, ExpressionEncoder, RowEncoder }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.StructType
import org.apache.avro.Schema

import cloudflow.spark.sql.SQLImplicits._

case class EncodedKV(key: String, value: Array[Byte])

case class SparkAvroDecoder[T: Encoder: TypeTag](avroSchema: String) {

  val encoder: Encoder[T]                           = implicitly[Encoder[T]]
  val sqlSchema: StructType                         = encoder.schema
  val encoderForDataColumns: ExpressionEncoder[Row] = RowEncoder(sqlSchema)
  @transient lazy val _avroSchema                   = new Schema.Parser().parse(avroSchema)
  @transient lazy val rowConverter                  = SchemaConverters.createConverterToSQL(_avroSchema, sqlSchema)
  @transient lazy val datumReader                   = new GenericDatumReader[GenericRecord](_avroSchema)
  @transient lazy val decoder                       = DecoderFactory.get
  def decode(bytes: Array[Byte]): Row = {
    val binaryDecoder = decoder.binaryDecoder(bytes, null)
    val record        = datumReader.read(null, binaryDecoder)
    rowConverter(record).asInstanceOf[GenericRow]
  }

}


case class SparkAvroEncoder[T: Encoder: TypeTag](avroSchema: String) {

  @transient lazy val log = Logger.getLogger(getClass.getName)

  val BufferSize = 5 * 1024 // 5 Kb

  val encoder                     = implicitly[Encoder[T]]
  val sqlSchema                   = encoder.schema
  @transient lazy val _avroSchema = new Schema.Parser().parse(avroSchema)

  val recordName                = "topLevelRecord" // ???
  val recordNamespace           = "recordNamespace" // ???
  @transient lazy val converter = AvroConverter.createConverterToAvro(sqlSchema, recordName, recordNamespace)

  // Risk: This process is memory intensive. Might require thread-level buffers to optimize memory usage
  def rowToBytes(row: Row): Array[Byte] = {
    val genRecord = converter(row).asInstanceOf[GenericRecord]
    if (log.isDebugEnabled) log.debug(s"genRecord = $genRecord")
    val datumWriter   = new GenericDatumWriter[GenericRecord](_avroSchema)
    val avroEncoder   = EncoderFactory.get
    val byteArrOS     = new ByteArrayOutputStream(BufferSize)
    val binaryEncoder = avroEncoder.binaryEncoder(byteArrOS, null)
    datumWriter.write(genRecord, binaryEncoder)
    binaryEncoder.flush()
    byteArrOS.toByteArray
  }

  def encode(dataset: Dataset[T]): Dataset[Array[Byte]] =
    dataset.toDF().mapPartitions(rows ⇒ rows.map(rowToBytes)).as[Array[Byte]]

  // Note to self: I'm not sure how heavy this chain of transformations is
  def encodeWithKey(dataset: Dataset[T], keyFun: T ⇒ String): Dataset[EncodedKV] = {
    val encoder             = encoderFor[T]
    implicit val rowEncoder = RowEncoder(encoder.schema).resolveAndBind()
    dataset.map { value ⇒
      val key         = keyFun(value)
      val internalRow = encoder.toRow(value)
      val row         = rowEncoder.fromRow(internalRow)
      val bytes       = rowToBytes(row)
      EncodedKV(key, bytes)
    }
  }

} 
Example 154
Source File: JdbcConstants.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.jdbc.conf

import org.apache.spark.sql.types.{StringType, StructField, StructType}

import com.paypal.gimel.common.conf.GimelConstants

object JdbcConstants {

  // basic variable references
  val HDFS_PREFIX: String = "hdfs:///user"
  val TD_PASS_FILENAME_DEFAULT: String = "pass.dat"
  val P_FILEPATH = s"${HDFS_PREFIX}/${GimelConstants.USER_NAME}/password/teradata/${TD_PASS_FILENAME_DEFAULT}"
  val MYSQL = "MYSQL"
  val TERADATA = "TERADATA"
  val ORCALE = "ORACLE"
  val POSTGRESQL = "POSTGRESQL"

  val HDFS_FILE_SOURCE = "hdfs"
  val LOCAL_FILE_SOURCE = "local"
  val DEFAULT_P_FILE_SOURCE = HDFS_FILE_SOURCE
  val JDBC_FILE_PASSWORD_STRATEGY = "file"
  val JDBC_INLINE_PASSWORD_STRATEGY = "inline"
  val JDBC_PROXY_USERNAME = "gimelproxyuser"
  val JDBC_CUSTOM_PASSWORD_STRATEGY = "custom"
  val JDBC_DEFAULT_PASSWORD_STRATEGY = JDBC_CUSTOM_PASSWORD_STRATEGY
  val JDBC_AUTH_REQUEST_TYPE = "JDBC"

  // default TD properties
  val DEFAULT_TD_SESSIONS = 5
  val DEFAULT_CHARSET = "UTF16"
  val DEFAULT_SESSIONS = "6"
  val TD_FASTLOAD_KEY: String = "FASTLOAD"
  val TD_FASTEXPORT_KEY: String = "FASTEXPORT"
  val TD_FASTLOAD_KEY_LC: String = TD_FASTLOAD_KEY.toLowerCase
  val TD_FASTEXPORT_KEY_LC: String = TD_FASTEXPORT_KEY.toLowerCase

  // JDBC READ configs
  val MAX_TD_JDBC_READ_PARTITIONS = 24
  val MAX_FAST_EXPORT_READ_PARTITIONS = 2
  val DEFAULT_READ_FETCH_SIZE = 1000
  val DEFAULT_LOWER_BOUND = 0
  val DEFAULT_UPPER_BOUND = 20
  val DEFAULT_READ_TYPE = "BATCH"
  val READ_OPERATION = "read"

  // JDBC write configs
  val GIMEL_TEMP_PARTITION = "GIMEL_TEMP_PARTITION"
  val DEFAULT_WRITE_BATCH_SIZE = 10000
  val MAX_TD_JDBC_WRITE_PARTITIONS: Int = 24
  val MAX_FAST_LOAD_WRITE_PARTITIONS: Int = 2
  val DEFAULT_INSERT_STRATEGY = "insert"
  val DEFAULT_WRITE_TYPE = "BATCH"
  val WRITE_OPERATION = "write"
  val REPARTITION_METHOD = "repartition"
  val COALESCE_METHOD = "coalesce"

  // partitions for Systems other than Teradata
  val DEFAULT_JDBC_READ_PARTITIONS = 100
  val DEFAULT_JDBC_WRTIE_PARTITIONS = 100
  val DEFAULT_JDBC_PER_PROCESS_MAX_ROWS_LIMIT: Long = 1000000000L
  val DEFAULT_JDBC_PER_PROCESS_MAX_ROWS_LIMIT_STRING: String = DEFAULT_JDBC_PER_PROCESS_MAX_ROWS_LIMIT.toString

  // pushdown constants
  val DEF_JDBC_PUSH_DOWN_SCHEMA: StructType = new StructType(fields = Seq(
    StructField("QUERY_EXECUTION", StringType, nullable = false)
  ).toArray)
} 
Example 155
Source File: RestApiConsumer.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.restapi.reader

import scala.language.implicitConversions

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import com.paypal.gimel.common.gimelservices.GimelServiceUtilities
import com.paypal.gimel.logger.Logger
import com.paypal.gimel.restapi.conf.RestApiClientConfiguration



object RestApiConsumer {

  val logger: Logger = Logger()
  val utils: GimelServiceUtilities = GimelServiceUtilities()

  def consume(sparkSession: SparkSession, conf: RestApiClientConfiguration): DataFrame = {
    def MethodName: String = new Exception().getStackTrace().apply(1).getMethodName()
    logger.info(" @Begin --> " + MethodName)

    val responsePayload = conf.httpsFlag match {
      case false => utils.get(conf.resolvedUrl.toString)
      case true => utils.httpsGet(conf.resolvedUrl.toString)
    }
    conf.parsePayloadFlag match {
      case false =>
        logger.info("NOT Parsing payload.")
        val rdd: RDD[String] = sparkSession.sparkContext.parallelize(Seq(responsePayload))
        val rowRdd: RDD[Row] = rdd.map(Row(_))
        val field: StructType = StructType(Seq(StructField(conf.payloadFieldName, StringType)))
        sparkSession.sqlContext.createDataFrame(rowRdd, field)
      case true =>
        logger.info("Parsing payload to fields - as requested.")
        val rdd: RDD[String] = sparkSession.sparkContext.parallelize(Seq(responsePayload))
        sparkSession.sqlContext.read.json(rdd)
    }
  }

} 
Example 156
Source File: PulsarRelation.scala    From pulsar-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.pulsar

import java.{util => ju}

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.json.JSONOptionsInRead
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType


private[pulsar] class PulsarRelation(
    override val sqlContext: SQLContext,
    override val schema: StructType,
    schemaInfo: SchemaInfoSerializable,
    adminUrl: String,
    clientConf: ju.Map[String, Object],
    readerConf: ju.Map[String, Object],
    startingOffset: SpecificPulsarOffset,
    endingOffset: SpecificPulsarOffset,
    pollTimeoutMs: Int,
    failOnDataLoss: Boolean,
    subscriptionNamePrefix: String,
    jsonOptions: JSONOptionsInRead)
    extends BaseRelation
    with TableScan
    with Logging {

  import PulsarSourceUtils._

  val reportDataLoss = reportDataLossFunc(failOnDataLoss)

  override def buildScan(): RDD[Row] = {
    val fromTopicOffsets = startingOffset.topicOffsets
    val endTopicOffsets = endingOffset.topicOffsets

    if (fromTopicOffsets.keySet != endTopicOffsets.keySet) {
      val fromTopics = fromTopicOffsets.keySet.toList.sorted.mkString(",")
      val endTopics = endTopicOffsets.keySet.toList.sorted.mkString(",")
      throw new IllegalStateException(
        "different topics " +
          s"for starting offsets topics[${fromTopics}] and " +
          s"ending offsets topics[${endTopics}]")
    }

    val offsetRanges = endTopicOffsets.keySet
      .map { tp =>
        val fromOffset = fromTopicOffsets.getOrElse(tp, {
          // this shouldn't happen since we had checked it
          throw new IllegalStateException(s"$tp doesn't have a from offset")
        })
        val untilOffset = endTopicOffsets(tp)
        PulsarOffsetRange(tp, fromOffset, untilOffset, None)
      }
      .filter { range =>
        if (range.untilOffset.compareTo(range.fromOffset) < 0) {
          reportDataLoss(
            s"${range.topic}'s offset was changed " +
              s"from ${range.fromOffset} to ${range.untilOffset}, " +
              "some data might has been missed")
          false
        } else {
          true
        }
      }
      .toSeq

    val rdd = new PulsarSourceRDD4Batch(
      sqlContext.sparkContext,
      schemaInfo,
      adminUrl,
      clientConf,
      readerConf,
      offsetRanges,
      pollTimeoutMs,
      failOnDataLoss,
      subscriptionNamePrefix,
      jsonOptions
    )
    sqlContext.internalCreateDataFrame(rdd.setName("pulsar"), schema).rdd
  }
} 
Example 157
Source File: PulsarStreamWriter.scala    From pulsar-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.pulsar

import java.{util => ju}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.types.StructType


class PulsarStreamDataWriter(
    inputSchema: Seq[Attribute],
    clientConf: ju.Map[String, Object],
    producerConf: ju.Map[String, Object],
    topic: Option[String],
    adminUrl: String)
    extends PulsarRowWriter(inputSchema, clientConf, producerConf, topic, adminUrl)
    with DataWriter[InternalRow] {

  def write(row: InternalRow): Unit = {
    checkForErrors()
    sendRow(row)
  }

  def commit(): WriterCommitMessage = {
    // Send is asynchronous, but we can't commit until all rows are actually in Pulsar.
    // This requires flushing and then checking that no callbacks produced errors.
    // We also check for errors before to fail as soon as possible - the check is cheap.
    checkForErrors()
    producerFlush()
    checkForErrors()
    PulsarWriterCommitMessage
  }

  def abort(): Unit = {}

  def close(): Unit = {
    checkForErrors()
    producerClose()
    checkForErrors()
  }
} 
Example 158
Source File: SqlUtils.scala    From spark-acid   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.types.StructType

object SqlUtils {
  def convertToDF(sparkSession: SparkSession, plan : LogicalPlan): DataFrame = {
    Dataset.ofRows(sparkSession, plan)
  }

  def resolveReferences(sparkSession: SparkSession,
                        expr: Expression,
                        planContaining: LogicalPlan, failIfUnresolved: Boolean,
                        exprName: Option[String] = None): Expression = {
    resolveReferences(sparkSession, expr, Seq(planContaining), failIfUnresolved, exprName)
  }

  def resolveReferences(sparkSession: SparkSession,
                        expr: Expression,
                        planContaining: Seq[LogicalPlan],
                        failIfUnresolved: Boolean,
                        exprName: Option[String]): Expression = {
    val newPlan = FakeLogicalPlan(expr, planContaining)
    val resolvedExpr = sparkSession.sessionState.analyzer.execute(newPlan) match {
      case FakeLogicalPlan(resolvedExpr: Expression, _) =>
        // Return even if it did not successfully resolve
        resolvedExpr
      case _ =>
        expr
      // This is unexpected
    }
    if (failIfUnresolved) {
      resolvedExpr.flatMap(_.references).filter(!_.resolved).foreach {
        attr => {
          val failedMsg = exprName match {
            case Some(name) => s"${attr.sql} resolution in $name given these columns: "+
              planContaining.flatMap(_.output).map(_.name).mkString(",")
            case _ => s"${attr.sql} resolution failed given these columns: "+
              planContaining.flatMap(_.output).map(_.name).mkString(",")
          }
          attr.failAnalysis(failedMsg)
        }
      }
    }
    resolvedExpr
  }

  def hasSparkStopped(sparkSession: SparkSession): Boolean = {
    sparkSession.sparkContext.stopped.get()
  }

  
  def createDataFrameUsingAttributes(sparkSession: SparkSession,
                                     rdd: RDD[Row],
                                     schema: StructType,
                                     attributes: Seq[Attribute]): DataFrame = {
    val encoder = RowEncoder(schema)
    val catalystRows = rdd.map(encoder.toRow)
    val logicalPlan = LogicalRDD(
      attributes,
      catalystRows,
      isStreaming = false)(sparkSession)
    Dataset.ofRows(sparkSession, logicalPlan)
  }

  def analysisException(cause: String): Throwable = {
    new AnalysisException(cause)
  }
}

case class FakeLogicalPlan(expr: Expression, children: Seq[LogicalPlan])
  extends LogicalPlan {
  override def output: Seq[Attribute] = children.foldLeft(Seq[Attribute]())((out, child) => out ++ child.output)
} 
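
A spark-shell style sketch of how resolveReferences might be exercised; the DataFrame, its columns and the session settings are illustrative assumptions, not part of the original file:

import org.apache.spark.sql.{SparkSession, SqlUtils}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

val spark = SparkSession.builder().master("local[2]").appName("sqlutils-sketch").getOrCreate()
import spark.implicits._

val df = Seq((1, "a"), (2, "b")).toDF("id", "name")

// Resolve the unresolved column reference "name" against df's logical plan;
// with failIfUnresolved = true an unknown column fails analysis instead of
// silently coming back unresolved.
val resolved = SqlUtils.resolveReferences(
  spark, UnresolvedAttribute("name"), df.queryExecution.logical, failIfUnresolved = true)
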
Example 159
Source File: Executor.scala    From neo4j-spark-connector   with Apache License 2.0 5 votes vote down vote up
package org.neo4j.spark

import java.time.{LocalDate, LocalDateTime, OffsetTime, ZoneOffset, ZonedDateTime}
import java.util
import java.sql.Timestamp

import org.apache.spark.SparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.StructType
import org.neo4j.spark.dataframe.CypherTypes
import org.neo4j.spark.utils.{Neo4jSessionAwareIterator, Neo4jUtils}

import scala.collection.JavaConverters._


object Executor {

  def convert(value: AnyRef): Any = value match {
    case it: util.Collection[_] => it.toArray()
    case m: java.util.Map[_,_] => m.asScala
    case _ => Neo4jUtils.convert(value)
  }

  def toJava(parameters: Map[String, Any]): java.util.Map[String, Object] = {
    parameters.mapValues(toJava).asJava
  }

  private def toJava(x: Any): AnyRef = x match {
    case y: Seq[_] => y.asJava
    case _ => x.asInstanceOf[AnyRef]
  }

  val EMPTY = Array.empty[Any]

  val EMPTY_RESULT = new CypherResult(new StructType(), Iterator.empty)

  class CypherResult(val schema: StructType, val rows: Iterator[Array[Any]]) {
    def sparkRows: Iterator[Row] = rows.map(row => new GenericRowWithSchema(row, schema))

    def fields = schema.fieldNames
  }

  def execute(sc: SparkContext, query: String, parameters: Map[String, AnyRef]): CypherResult = {
    execute(Neo4jConfig(sc.getConf), query, parameters)
  }

  private def rows(result: Iterator[_]) = {
    var i = 0
    // advance the iterator while counting; checking hasNext alone would never terminate
    while (result.hasNext) { result.next(); i = i + 1 }
    i
  }

  def execute(config: Neo4jConfig, query: String, parameters: Map[String, Any], write: Boolean = false): CypherResult = {
    val result = new Neo4jSessionAwareIterator(config, query, toJava(parameters), write)
    if (!result.hasNext) {
      return EMPTY_RESULT
    }
    val peek = result.peek()
    val keyCount = peek.size()
    if (keyCount == 0) {
      return new CypherResult(new StructType(), Array.fill[Array[Any]](rows(result))(EMPTY).toIterator)
    }
    val keys = peek.keys().asScala
    val fields = keys.map(k => (k, peek.get(k).`type`())).map(keyType => CypherTypes.field(keyType))
    val schema = StructType(fields)
    val it = result.map(record => {
      val row = new Array[Any](keyCount)
      var i = 0
      while (i < keyCount) {
        val value = convert(record.get(i).asObject())
        row.update(i, value)
        i = i + 1
      }
      row
    })
    new CypherResult(schema, it)
  }
} 
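
A spark-shell style sketch of driving Executor.execute from a SparkContext; the Cypher statement, the bolt URL and the assumption that a Neo4j instance is reachable are all placeholders:

import org.apache.spark.{SparkConf, SparkContext}
import org.neo4j.spark.Executor

val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("cypher-sketch")
  .set("spark.neo4j.bolt.url", "bolt://localhost:7687") // connection settings are placeholders
val sc = new SparkContext(conf)

// Run a read-only Cypher statement and pull the result back as Spark Rows.
val result = Executor.execute(sc, "MATCH (n:Person) RETURN n.name AS name LIMIT 5", Map.empty)
result.sparkRows.foreach(println)
println(result.fields.mkString(", ")) // column names derived from the Cypher result
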
Example 160
Source File: DefaultSource.scala    From spark-gdb   with Apache License 2.0 5 votes vote down vote up
package com.esri.gdb

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType


// class declaration assumed here (truncated in this excerpt)
class DefaultSource extends SchemaRelationProvider {

  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String],
                              schema: StructType
                             ): BaseRelation = {
    val path = parameters.getOrElse("path", sys.error("Parameter 'path' must be defined."))
    val name = parameters.getOrElse("name", sys.error("Parameter 'name' must be defined."))
    val numPartitions = parameters.getOrElse("numPartitions", "8").toInt
    GDBRelation(path, name, numPartitions)(sqlContext)
  }
} 
Example 161
Source File: GDBRowIterator.scala    From spark-gdb   with Apache License 2.0 5 votes vote down vote up
package com.esri.gdb

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType


class GDBRowIterator(indexIter: Iterator[IndexInfo], dataBuffer: DataBuffer, fields: Array[Field], schema: StructType)
  extends Iterator[Row] with Serializable {

  val numFieldsWithNullAllowed = fields.count(_.nullable)
  val nullValueMasks = new Array[Byte]((numFieldsWithNullAllowed / 8.0).ceil.toInt)

  def hasNext() = indexIter.hasNext

  def next() = {
    val index = indexIter.next()
    val numBytes = dataBuffer.seek(index.seek).readBytes(4).getInt
    val byteBuffer = dataBuffer.readBytes(numBytes)
    0 until nullValueMasks.length foreach (nullValueMasks(_) = byteBuffer.get)
    var bit = 0
    val values = fields.map(field => {
      if (field.nullable) {
        val i = bit >> 3
        val m = 1 << (bit & 7)
        bit += 1
        if ((nullValueMasks(i) & m) == 0) {
          field.readValue(byteBuffer, index.objectID)
        }
        else {
          null // TODO - Do not like null here - but...it is nullable !
        }
      } else {
        field.readValue(byteBuffer, index.objectID)
      }
    }
    )
    new GenericRowWithSchema(values, schema)
  }
} 
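
The null handling above packs one bit per nullable field into nullValueMasks (byte index = bit >> 3, bit position = bit & 7, a set bit meaning null). A tiny standalone sketch of the same decoding:

object NullMaskSketch extends App {
  // Mirrors the check above, where (mask & m) == 0 means "value present".
  def isNull(masks: Array[Byte], bit: Int): Boolean =
    (masks(bit >> 3) & (1 << (bit & 7))) != 0

  val masks = Array(0x05.toByte)  // 0000 0101
  assert(isNull(masks, 0))        // bit 0 set   -> field 0 is null
  assert(!isNull(masks, 1))       // bit 1 clear -> field 1 has a value
  assert(isNull(masks, 2))        // bit 2 set   -> field 2 is null
}
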
Example 162
Source File: TransformerSpec.scala    From modelmatrix   with Apache License 2.0 5 votes vote down vote up
package com.collective.modelmatrix.transform

import com.collective.modelmatrix.{ModelFeature, ModelMatrix, TestSparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FlatSpec

class TransformerSpec extends FlatSpec with TestSparkContext {

  val sqlContext = ModelMatrix.sqlContext(sc)

  val schema = StructType(Seq(
    StructField("adv_site", StringType),
    StructField("adv_id", IntegerType)
  ))

  val input = Seq(
    Row("cnn.com", 1),
    Row("bbc.com", 2),
    Row("hbo.com", 1),
    Row("mashable.com", 3)
  )

  val isActive = true
  val withAllOther = true

  // Can't call 'day_of_week' with String
  val badFunctionType = ModelFeature(isActive, "advertisement", "f1", "day_of_week(adv_site, 'UTC')", Top(95.0, allOther = false))

  // Not enough parameters for 'concat'
  val wrongParametersCount = ModelFeature(isActive, "advertisement", "f2", "concat(adv_site)", Top(95.0, allOther = false))

  val df = sqlContext.createDataFrame(sc.parallelize(input), schema)

  "Transformer" should "report failed feature extraction" in {
    val features = Transformer.extractFeatures(df, Seq(badFunctionType, wrongParametersCount))
    assert(features.isLeft)
    val errors = features.fold(identity, _ => sys.error("Should not be here"))

    assert(errors.length == 2)
    assert(errors(0).feature == badFunctionType)
    assert(errors(1).feature == wrongParametersCount)
  }

} 
Example 163
Source File: IdentityTransformerSpec.scala    From modelmatrix   with Apache License 2.0 5 votes vote down vote up
package com.collective.modelmatrix.transform

import com.collective.modelmatrix.{ModelFeature, ModelMatrix, TestSparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FlatSpec

import scalaz.syntax.either._
import scalaz.{-\/, \/-}

class IdentityTransformerSpec extends FlatSpec with TestSparkContext {

  val sqlContext = ModelMatrix.sqlContext(sc)

  val schema = StructType(Seq(
    StructField("adv_site", StringType),
    StructField("adv_id", IntegerType)
  ))

  val input = Seq(
      Row("cnn.com", 1),
      Row("bbc.com", 2),
      Row("hbo.com", 1),
      Row("mashable.com", 3)
  )

  val isActive = true
  val withAllOther = true

  val adSite = ModelFeature(isActive, "Ad", "ad_site", "adv_site", Identity)
  val adId = ModelFeature(isActive, "Ad", "ad_id", "adv_id", Identity)

  val df = sqlContext.createDataFrame(sc.parallelize(input), schema)
  val transformer = new IdentityTransformer(Transformer.extractFeatures(df, Seq(adSite, adId)) match {
    case -\/(err) => sys.error(s"Can't extract features: $err")
    case \/-(suc) => suc
  })

  "Identity Transformer" should "support integer typed model feature" in {
    val valid = transformer.validate(adId)
    assert(valid == TypedModelFeature(adId, IntegerType).right)
  }

  it should "fail if feature column doesn't exists" in {
    val failed = transformer.validate(adSite.copy(feature = "adv_site"))
    assert(failed == FeatureTransformationError.FeatureColumnNotFound("adv_site").left)
  }

  it should "fail if column type is not supported" in {
    val failed = transformer.validate(adSite)
    assert(failed == FeatureTransformationError.UnsupportedTransformDataType("ad_site", StringType, Identity).left)
  }

} 
Example 164
Source File: DataFrameNaFunctionsSpec.scala    From spark-spec   with MIT License 5 votes vote down vote up
package com.github.mrpowers.spark.spec.sql

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.Row
import org.scalatest._
import com.github.mrpowers.spark.daria.sql.SparkSessionExt._

import com.github.mrpowers.spark.spec.SparkSessionTestWrapper

import com.github.mrpowers.spark.fast.tests.DatasetComparer

class DataFrameNaFunctionsSpec extends FunSpec with SparkSessionTestWrapper with DatasetComparer {

  import spark.implicits._

  describe("#drop") {

    it("drops rows that contains null values") {

      val sourceData = List(
        Row(1, null),
        Row(null, null),
        Row(3, 30),
        Row(10, 20)
      )

      val sourceSchema = List(
        StructField("num1", IntegerType, true),
        StructField("num2", IntegerType, true)
      )

      val sourceDF = spark.createDataFrame(
        spark.sparkContext.parallelize(sourceData),
        StructType(sourceSchema)
      )

      val actualDF = sourceDF.na.drop()

      val expectedData = List(
        Row(3, 30),
        Row(10, 20)
      )

      val expectedSchema = List(
        StructField("num1", IntegerType, true),
        StructField("num2", IntegerType, true)
      )

      val expectedDF = spark.createDataFrame(
        spark.sparkContext.parallelize(expectedData),
        StructType(expectedSchema)
      )

      assertSmallDatasetEquality(actualDF, expectedDF)

    }

  }

  describe("#fill") {

    it("Returns a new DataFrame that replaces null or NaN values in numeric columns with value") {

      val sourceDF = spark.createDF(
        List(
          (1, null),
          (null, null),
          (3, 30),
          (10, 20)
        ), List(
          ("num1", IntegerType, true),
          ("num2", IntegerType, true)
        )
      )

      val actualDF = sourceDF.na.fill(77)

      val expectedDF = spark.createDF(
        List(
          (1, 77),
          (77, 77),
          (3, 30),
          (10, 20)
        ), List(
          ("num1", IntegerType, false),
          ("num2", IntegerType, false)
        )
      )

      assertSmallDatasetEquality(actualDF, expectedDF)

    }

  }

  describe("#replace") {
    pending
  }

} 
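
The same drop/fill behaviour can be reproduced outside the test harness. A minimal sketch against a plain local SparkSession; the column names and values are illustrative:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object NaFunctionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("na-sketch").getOrCreate()

    val schema = StructType(List(
      StructField("num1", IntegerType, nullable = true),
      StructField("num2", IntegerType, nullable = true)))

    val df = spark.createDataFrame(
      spark.sparkContext.parallelize(List(Row(1, null), Row(null, null), Row(3, 30), Row(10, 20))),
      schema)

    df.na.drop().show()   // keeps only (3, 30) and (10, 20)
    df.na.fill(77).show() // nulls in numeric columns become 77
    spark.stop()
  }
}
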
Example 165
Source File: DefaultSource.scala    From spark-iqmulus   with Apache License 2.0 5 votes vote down vote up
package fr.ign.spark.iqmulus.ply

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{ HadoopFsRelation, HadoopFsRelationProvider }
import org.apache.spark.sql.types.StructType

class DefaultSource extends HadoopFsRelationProvider {

  // override def shortName(): String = "ply"

  override def createRelation(
    sqlContext: SQLContext,
    paths: Array[String],
    dataSchema: Option[StructType],
    partitionColumns: Option[StructType],
    parameters: Map[String, String]
  ): HadoopFsRelation = {
    new PlyRelation(paths, dataSchema, partitionColumns, parameters)(sqlContext)
  }
} 
Example 166
Source File: package.scala    From spark-iqmulus   with Apache License 2.0 5 votes vote down vote up
package fr.ign.spark.iqmulus

import org.apache.spark.sql.{ SQLContext, DataFrameReader, DataFrameWriter, DataFrame, Row }
import org.apache.spark.sql.types.StructType

package object ply {

  
  implicit class PlyDataFrameReader(reader: DataFrameReader) {
    def ply: String => DataFrame = reader.format("fr.ign.spark.iqmulus.ply").load
  }

  implicit class PlyDataFrame(df: DataFrame) {
    def saveAsPly(location: String, littleEndian: Boolean = true) = {
      val df_id = df.drop("pid").drop("fid")
      val schema = df_id.schema
      val saver = (key: Int, iter: Iterator[Row]) =>
        Iterator(iter.saveAsPly(s"$location/$key.ply", schema, littleEndian))
      df_id.rdd.mapPartitionsWithIndex(saver, true).collect
    }
  }

  implicit class PlyRowIterator(iter: Iterator[Row]) {
    def saveAsPly(
      filename: String,
      schema: StructType,
      littleEndian: Boolean
    ) = {
      val path = new org.apache.hadoop.fs.Path(filename)
      val fs = path.getFileSystem(new org.apache.hadoop.conf.Configuration)
      val f = fs.create(path)
      val rows = iter.toArray
      val count = rows.size.toLong
      val header = new PlyHeader(filename, littleEndian, Map("vertex" -> ((count, schema))))
      val dos = new java.io.DataOutputStream(f);
      dos.write(header.toString.getBytes)
      val ros = new RowOutputStream(dos, littleEndian, schema)
      rows.foreach(ros.write)
      dos.close
      header
    }
  }
} 
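
A hedged usage sketch of the two implicit classes above; the file locations are placeholders and the spark-iqmulus package is assumed to be on the classpath:

import org.apache.spark.sql.SQLContext
import fr.ign.spark.iqmulus.ply._

// Read a PLY point cloud and write it back, one <partition>.ply file per partition.
def plyRoundTrip(sqlContext: SQLContext): Unit = {
  val cloud = sqlContext.read.ply("/data/in/cloud.ply")   // via PlyDataFrameReader
  cloud.saveAsPly("/data/out/cloud", littleEndian = true) // via PlyDataFrame
}
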
Example 167
Source File: DefaultSource.scala    From spark-iqmulus   with Apache License 2.0 5 votes vote down vote up
package fr.ign.spark.iqmulus.xyz

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{ HadoopFsRelation, HadoopFsRelationProvider }
import org.apache.spark.sql.types.StructType

class DefaultSource extends HadoopFsRelationProvider {

  // override def shortName(): String = "xyz"

  override def createRelation(
    sqlContext: SQLContext,
    paths: Array[String],
    dataSchema: Option[StructType],
    partitionColumns: Option[StructType],
    parameters: Map[String, String]
  ): HadoopFsRelation = {
    new XyzRelation(paths, dataSchema, partitionColumns, parameters)(sqlContext)
  }
} 
Example 168
Source File: package.scala    From spark-iqmulus   with Apache License 2.0 5 votes vote down vote up
package fr.ign.spark.iqmulus

import org.apache.spark.sql.{ SQLContext, DataFrameReader, DataFrameWriter, DataFrame, Row }
import org.apache.spark.sql.types.{ FloatType, StructType }

package object xyz {

  
  implicit class XyzDataFrameReader(reader: DataFrameReader) {
    def xyz: String => DataFrame = reader.format("fr.ign.spark.iqmulus.xyz").load
  }

  implicit class XyzDataFrame(df: DataFrame) {
    def saveAsXyz(location: String) = {
      val df_id = df.drop("id")
      require(df_id.schema.fieldNames.take(3) sameElements Array("x", "y", "z"))
      require(df_id.schema.fields.map(_.dataType).take(3).forall(_ == FloatType))
      val saver = (key: Int, iter: Iterator[Row]) => Iterator(iter.saveXyz(s"$location/$key.xyz"))
      df_id.rdd.mapPartitionsWithIndex(saver, true).collect
    }
  }

  implicit class XyzRowIterator(iter: Iterator[Row]) {
    def saveXyz(filename: String) = {
      val path = new org.apache.hadoop.fs.Path(filename)
      val fs = path.getFileSystem(new org.apache.hadoop.conf.Configuration)
      val f = fs.create(path)
      val dos = new java.io.DataOutputStream(f)
      var count = 0L
      iter.foreach(row => { count += 1; dos.writeBytes(row.mkString("", "\t", "\n")) })
      dos.close
      (filename, count)
    }
  }
} 
Example 169
Source File: DefaultSource.scala    From spark-iqmulus   with Apache License 2.0 5 votes vote down vote up
package fr.ign.spark.iqmulus.las

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{ HadoopFsRelation, HadoopFsRelationProvider }
import org.apache.spark.sql.types.StructType

class DefaultSource extends HadoopFsRelationProvider {

  // override def shortName(): String = "las"

  override def createRelation(
    sqlContext: SQLContext,
    paths: Array[String],
    dataSchema: Option[StructType],
    partitionColumns: Option[StructType],
    parameters: Map[String, String]
  ): HadoopFsRelation = {
    new LasRelation(paths, dataSchema, partitionColumns, parameters)(sqlContext)
  }
} 
Example 170
Source File: package.scala    From spark-iqmulus   with Apache License 2.0 5 votes vote down vote up
package fr.ign.spark.iqmulus

import org.apache.spark.sql.{ SQLContext, DataFrameReader, DataFrameWriter, DataFrame }
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.Row

package object las {

  
  implicit class LasDataFrameReader(reader: DataFrameReader) {
    def las: String => DataFrame = reader.format("fr.ign.spark.iqmulus.las").load
  }

  implicit class LasDataFrame(df: DataFrame) {
    def saveAsLas(
      location: String,
      formatOpt: Option[Byte] = None,
      version: Version = Version(),
      scale: Array[Double] = Array(0.01, 0.01, 0.01),
      offset: Array[Double] = Array(0, 0, 0)
    ) = {
      val format = formatOpt.getOrElse(LasHeader.formatFromSchema(df.schema))
      val schema = LasHeader.schema(format) // no user types for now
      val cols = schema.fieldNames.intersect(df.schema.fieldNames)
      val saver = (key: Int, iter: Iterator[Row]) =>
        Iterator(iter.saveAsLas(s"$location/$key.las", schema, format, scale, offset, version))
      df.select(cols.head, cols.tail: _*).rdd.mapPartitionsWithIndex(saver, true).collect
    }
  }

  implicit class LasRowIterator(iter: Iterator[Row]) {
    def saveAsLas(
      filename: String, schema: StructType, format: Byte,
      scale: Array[Double], offset: Array[Double], version: Version = Version()
    ) = {
      // materialize the partition to access it in a single pass, TODO workaround that 
      val rows = iter.toArray
      val count = rows.length.toLong
      val pmin = Array.fill[Double](3)(Double.PositiveInfinity)
      val pmax = Array.fill[Double](3)(Double.NegativeInfinity)
      val countByReturn = Array.fill[Long](15)(0)
      rows.foreach { row =>
        val x = offset(0) + scale(0) * row.getAs[Int]("x").toDouble
        val y = offset(1) + scale(1) * row.getAs[Int]("y").toDouble
        val z = offset(2) + scale(2) * row.getAs[Int]("z").toDouble
        val ret = row.getAs[Byte]("flags") & 0x3
        countByReturn(ret) += 1
        pmin(0) = Math.min(pmin(0), x)
        pmin(1) = Math.min(pmin(1), y)
        pmin(2) = Math.min(pmin(2), z)
        pmax(0) = Math.max(pmax(0), x)
        pmax(1) = Math.max(pmax(1), y)
        pmax(2) = Math.max(pmax(2), z)
      }
      val path = new org.apache.hadoop.fs.Path(filename)
      val fs = path.getFileSystem(new org.apache.hadoop.conf.Configuration)
      val f = fs.create(path)
      val header = new LasHeader(filename, format, count, pmin, pmax, scale, offset,
        version = version, pdr_return_nb = countByReturn)
      val dos = new java.io.DataOutputStream(f);
      header.write(dos)
      val ros = new RowOutputStream(dos, littleEndian = true, schema)
      rows.foreach(ros.write)
      dos.close
      header
    }
  }
} 
Example 171
Source File: DataFrameConverterSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.utils

import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row}
import org.mockito.Mockito._
import org.scalatest.mock.MockitoSugar
import org.scalatest.{BeforeAndAfterAll, FunSpec, Matchers}
import play.api.libs.json.{JsArray, JsString, Json}
import test.utils.SparkContextProvider

import scala.collection.mutable

class DataFrameConverterSpec extends FunSpec with MockitoSugar with Matchers with BeforeAndAfterAll {

  lazy val spark = SparkContextProvider.sparkContext

  override protected def afterAll(): Unit = {
    spark.stop()
    super.afterAll()
  }

  val dataFrameConverter: DataFrameConverter = new DataFrameConverter
  val mockDataFrame = mock[DataFrame]
  val mockRdd = spark.parallelize(Seq(Row(new mutable.WrappedArray.ofRef(Array("test1", "test2")), 2, null)))
  val mockStruct = mock[StructType]
  val columns = Seq("foo", "bar").toArray

  doReturn(mockStruct).when(mockDataFrame).schema
  doReturn(columns).when(mockStruct).fieldNames
  doReturn(mockRdd).when(mockDataFrame).rdd

  describe("DataFrameConverter") {
    describe("#convert") {
      it("should convert to a valid JSON object") {
        val someJson = dataFrameConverter.convert(mockDataFrame, "json")
        val jsValue = Json.parse(someJson.get)
        jsValue \ "columns" should be (JsArray(Seq(JsString("foo"), JsString("bar"))))
        jsValue \ "rows" should be (JsArray(Seq(
          JsArray(Seq(JsString("[test1, test2]"), JsString("2"), JsString("null")))
        )))
      }
      it("should convert to csv") {
        val csv = dataFrameConverter.convert(mockDataFrame, "csv").get
        val values = csv.split("\n")
        values(0) shouldBe "foo,bar"
        values(1) shouldBe "[test1, test2],2,null"
      }
      it("should convert to html") {
        val html = dataFrameConverter.convert(mockDataFrame, "html").get
        html.contains("<th>foo</th>") should be(true)
        html.contains("<th>bar</th>") should be(true)
        html.contains("<td>[test1, test2]</td>") should be(true)
        html.contains("<td>2</td>") should be(true)
        html.contains("<td>null</td>") should be(true)
      }
      it("should convert limit the selection") {
        val someLimited = dataFrameConverter.convert(mockDataFrame, "csv", 1)
        val limitedLines = someLimited.get.split("\n")
        limitedLines.length should be(2)
      }
      it("should return a Failure for invalid types") {
        val result = dataFrameConverter.convert(mockDataFrame, "Invalid Type")
        result.isFailure should be(true)
      }
    }
  }
} 
Example 172
Source File: DataFrameCreation.scala    From couchbase-spark-connector   with Apache License 2.0 5 votes vote down vote up
// Placed in the org.apache.spark.sql package to be able to access internalCreateDataFrame
// See my (currently unanswered) SO post for context:
// https://stackoverflow.com/questions/56183811/how-to-create-a-custom-structured-streaming-source-for-apache-spark-2-3-0
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StructType


object DataFrameCreation {
  def createStreamingDataFrame(sqlContext: SQLContext,
                                rdd: RDD[Row],
                                schema: StructType): DataFrame = {
    // internalCreateDataFrame requires an RDD[InternalRow]
    val encoder = RowEncoder.apply(schema)
    val encoded: RDD[InternalRow] = rdd.map(row => {
      encoder.toRow(row)
    })

    sqlContext.internalCreateDataFrame(encoded, schema, isStreaming = true)
  }

  def createStreamingDataFrame(sqlContext: SQLContext,
                               df: DataFrame,
                               schema: StructType): DataFrame = {
    // internalCreateDataFrame requires an RDD[InternalRow]
    val encoder = RowEncoder.apply(schema)
    val encoded: RDD[InternalRow] = df.rdd.map(row => {
      encoder.toRow(row)
    })
    sqlContext.internalCreateDataFrame(encoded, schema, isStreaming = true)
  }
} 
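
A sketch of a possible call site, assuming a custom streaming Source that already holds an SQLContext and a batch of rows; the schema and column names are illustrative:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, DataFrameCreation, Row, SQLContext}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

// E.g. from getBatch() of a custom structured streaming Source.
def batchToStreamingDF(sqlContext: SQLContext, batch: RDD[Row]): DataFrame = {
  val schema = StructType(Seq(
    StructField("key", StringType),
    StructField("value", LongType)))
  // Rows are encoded to InternalRow and wrapped with isStreaming = true.
  DataFrameCreation.createStreamingDataFrame(sqlContext, batch, schema)
}
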
Example 173
Source File: DataFrameReaderFunctions.scala    From couchbase-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.couchbase.spark.sql

import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, DataFrameReader}

class DataFrameReaderFunctions(@transient val dfr: DataFrameReader) extends Serializable {

  
  private def buildFrame(options: Map[String, String] = null, schema: StructType = null,
    schemaFilter: Option[Filter] = null): DataFrame = {
    val builder = dfr
      .format(source)
      .schema(schema)

    val filter = schemaFilter.map(N1QLRelation.filterToExpression)
    if (filter.isDefined) {
      builder.option("schemaFilter", filter.get)
    }

    if (options != null) {
      builder.options(options)
    }

    builder.load()
  }

} 
Example 174
Source File: N1qlSpec.scala    From couchbase-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.couchbase.spark.n1ql

import com.couchbase.client.core.CouchbaseException
import com.couchbase.client.java.error.QueryExecutionException
import com.couchbase.client.java.query.N1qlQuery
import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.sql.sources.EqualTo
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.scalatest._
import com.couchbase.spark._
import com.couchbase.spark.connection.CouchbaseConnection
import com.couchbase.spark.sql.N1QLRelation
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import scala.util.control.NonFatal

class N1qlSpec extends FunSuite with Matchers with BeforeAndAfterAll {

  private val master = "local[2]"
  private val appName = "cb-int-specs1"

  private var spark: SparkSession = _


  override def beforeAll(): Unit = {
    spark = SparkSession
      .builder()
      .master(master)
      .appName(appName)
      .config("spark.couchbase.username", "Administrator")
      .config("spark.couchbase.password", "password")
      // Open 2 buckets as tests below rely on it
      .config("com.couchbase.bucket.default", "")
      .config("com.couchbase.bucket.travel-sample", "")
      .getOrCreate()
  }

  override def afterAll(): Unit = {
    CouchbaseConnection().stop()
    spark.stop()
  }

  test("Creating N1QLRelation with default bucket, when two buckets exist, should fail") {
    assertThrows[IllegalStateException] {
      spark.read
        .format("com.couchbase.spark.sql.DefaultSource")
        .option("schemaFilter", N1QLRelation.filterToExpression(EqualTo("type", "airline")))
        .option("schemaFilter", "`type` = 'airline'")
        .schema(StructType(StructField("name", StringType) :: Nil))
        .load()
    }
  }

  test("Creating N1QLRelation with non-default bucket, when two buckets exist, should succeed") {
    spark.read
      .format("com.couchbase.spark.sql.DefaultSource")
      .option("schemaFilter", N1QLRelation.filterToExpression(EqualTo("type", "airline")))
      .option("schemaFilter", "`type` = 'airline'")
      .option("bucket", "travel-sample")
      .schema(StructType(StructField("name", StringType) :: Nil))
      .load()
  }

  test("N1QL failures should fail the Observable") {
    try {
      spark.sparkContext
        .couchbaseQuery(N1qlQuery.simple("BAD QUERY"), bucketName = "default")
        .collect()
        .foreach(println)
      fail()
    }
    catch {
      case e: SparkException =>
        assert (e.getCause.isInstanceOf[QueryExecutionException])
        val err = e.getCause.asInstanceOf[QueryExecutionException]
        assert (err.getMessage == "syntax error - at QUERY")
      case NonFatal(e) =>
        println(e)
        fail()
    }
  }
} 
Example 175
Source File: CouchbaseDataFrameSpec.scala    From couchbase-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.couchbase.spark.sql

import com.couchbase.spark.connection.CouchbaseConnection
import org.apache.avro.generic.GenericData.StringType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode, SparkSession}
import org.apache.spark.sql.sources.EqualTo
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.{SparkConf, SparkContext}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CouchbaseDataFrameSpec extends FlatSpec with Matchers with BeforeAndAfterAll {

  private val master = "local[2]"
  private val appName = "cb-int-specs1"

  private var spark: SparkSession = null


  override def beforeAll(): Unit = {
    val conf = new SparkConf()
      .setMaster(master)
      .setAppName(appName)
      .set("spark.couchbase.nodes", "127.0.0.1")
      .set("com.couchbase.username", "Administrator")
      .set("com.couchbase.password", "password")
      .set("com.couchbase.bucket.default", "")
      .set("com.couchbase.bucket.travel-sample", "")
    spark = SparkSession.builder().config(conf).getOrCreate()

    loadData()
  }

  override def afterAll(): Unit = {
    CouchbaseConnection().stop()
    spark.stop()
  }

  def loadData(): Unit = {

  }

  "If two buckets are used and the bucket is specified the API" should
    "not fail" in {
    val ssc = spark.sqlContext
    ssc.read.couchbase(EqualTo("type", "airline"), Map("bucket" -> "travel-sample"))
  }

  "The DataFrame API" should "infer the schemas" in {
    val ssc = spark.sqlContext
    import com.couchbase.spark.sql._

    val airline = ssc.read.couchbase(EqualTo("type", "airline"), Map("bucket" -> "travel-sample"))
    val airport = ssc.read.couchbase(EqualTo("type", "airport"), Map("bucket" -> "travel-sample"))
    val route = ssc.read.couchbase(EqualTo("type", "route"), Map("bucket" -> "travel-sample"))
    val landmark = ssc.read.couchbase(EqualTo("type", "landmark"), Map("bucket" -> "travel-sample"))


    airline
      .limit(10)
      .write
      .mode(SaveMode.Overwrite)
      .couchbase(Map("bucket" -> "default"))

    // TODO: validate schemas which are inferred on a field and type basis

  }

  it should "write and ignore" in {
    val ssc = spark.sqlContext
    import com.couchbase.spark.sql._

    // create df, write it twice
    val data = ("Michael", 28, true)
    val df = ssc.createDataFrame(spark.sparkContext.parallelize(Seq(data)))

    df.write
      .mode(SaveMode.Ignore)
      .couchbase(options = Map("idField" -> "_1", "bucket" -> "default"))

    df.write
      .mode(SaveMode.Ignore)
      .couchbase(options = Map("idField" -> "_1", "bucket" -> "default"))
  }

  it should "filter based on a function" in {
    val ssc = spark.sqlContext
    import com.couchbase.spark.sql._

    val airlineBySubstrCountry: DataFrame = ssc.read.couchbase(
      EqualTo("'substr(country, 0, 6)'", "United"), Map("bucket" -> "travel-sample"))

    airlineBySubstrCountry.count() should equal(6797)
  }

} 
Example 176
Source File: TestData.scala    From drunken-data-quality   with Apache License 2.0 5 votes vote down vote up
package de.frosner.ddq.testutils

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object TestData {

  def makeIntegerDf(spark: SparkSession, numbers: Seq[Int]): DataFrame =
    spark.createDataFrame(
      spark.sparkContext.makeRDD(numbers.map(Row(_))),
      StructType(List(StructField("column", IntegerType, nullable = false)))
    )

  def makeNullableStringDf(spark: SparkSession, strings: Seq[String]): DataFrame =
    spark.createDataFrame(spark.sparkContext.makeRDD(strings.map(Row(_))), StructType(List(StructField("column", StringType, nullable = true))))

  def makeIntegersDf(spark: SparkSession, row1: Seq[Int], rowN: Seq[Int]*): DataFrame = {
    val rows = row1 :: rowN.toList
    val numCols = row1.size
    val rdd = spark.sparkContext.makeRDD(rows.map(Row(_:_*)))
    val schema = StructType((1 to numCols).map(idx => StructField("column" + idx, IntegerType, nullable = false)))
    spark.createDataFrame(rdd, schema)
  }

} 
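
A short sketch of how these helpers might be called; the local SparkSession setup is an assumption made for the sketch, the real suites provide their own:

import org.apache.spark.sql.SparkSession
import de.frosner.ddq.testutils.TestData

object TestDataSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("testdata-sketch").getOrCreate()

    val ints    = TestData.makeIntegerDf(spark, Seq(1, 2, 3))          // one non-nullable "column"
    val strings = TestData.makeNullableStringDf(spark, Seq("a", null)) // one nullable "column"
    val pairs   = TestData.makeIntegersDf(spark, Seq(1, 2), Seq(3, 4)) // "column1" and "column2"

    assert(ints.count() == 3 && strings.count() == 2 && pairs.count() == 2)
    spark.stop()
  }
}
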
Example 177
Source File: XGBoostInference.scala    From xgbspark-text-classification   with Apache License 2.0 5 votes vote down vote up
package com.lenovo.ml


import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType
import DataPreprocess.segWords
import org.apache.spark.ml.PipelineModel

object XGBoostInference {
  def main(args:Array[String]): Unit = {
    // 1. Create the Spark application entry point
    val sparkSession = SparkSession.builder().appName("XGBoostInference").enableHiveSupport().getOrCreate()

    // 2. Read the training data, preprocess the text and segment it into words
    val tableName = args(0)
    val matrix = sparkSession.sql("SELECT * FROM " + tableName)
    val words = segWords(sparkSession, args(1), args(2), args(3), args(4), matrix.select("text"))

    // 3. Join the original data with the word segmentation results
    val rows = matrix.rdd.zip(words.rdd).map{
      case (rowLeft, rowRight) => Row.fromSeq(rowLeft.toSeq ++ rowRight.toSeq)
    }
    val schema = StructType(matrix.schema.fields ++ words.schema.fields)
    val matrixMerge = sparkSession.createDataFrame(rows, schema)

    // 4. Build the feature vectors
    val featuredModelTrained = sparkSession.sparkContext.broadcast(PipelineModel.read.load(args(5)))
    val dataPrepared = featuredModelTrained.value.transform(matrixMerge).repartition(18).cache()

    // 5. Load the classification model and produce the failure predictions
    val xgbModelTrained = sparkSession.sparkContext.broadcast(PipelineModel.read.load(args(6)))
    val prediction = xgbModelTrained.value.transform(dataPrepared)

    // 6. Write the prediction results to HDFS
    prediction.select("text", "predictedLabel", "probabilities").rdd.coalesce(1).saveAsTextFile(args(7))

    sparkSession.stop()
  }
} 
Example 178
Source File: SparkSupport.scala    From aardpfark   with Apache License 2.0 5 votes vote down vote up
package com.ibm.aardpfark.spark.ml

import com.ibm.aardpfark.avro.SchemaConverters
import com.ibm.aardpfark.pfa.document.{PFADocument, ToPFA}
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.sql.types.StructType


object SparkSupport {


  def toPFA(t: Transformer, pretty: Boolean): String = {
    toPFATransformer(t).pfa.toJSON(pretty)
  }

  def toPFA(p: PipelineModel, s: StructType, pretty: Boolean): String = {
    val inputFields = s.map { f => f.copy(nullable = false) }
    val inputSchema = StructType(inputFields)
    val pipelineInput = SchemaBuilder.record(s"Input_${p.uid}")
    val inputAvroSchema = SchemaConverters.convertStructToAvro(inputSchema, pipelineInput, "")
    Merge.mergePipeline(p, inputAvroSchema).toJSON(pretty)
  }

  // testing implicit conversions for Spark ML PipelineModel and Transformer to PFA / JSON

  implicit private[aardpfark] def toPFATransformer(transformer: org.apache.spark.ml.Transformer): ToPFA = {

    val pkg = transformer.getClass.getPackage.getName
    val name = transformer.getClass.getSimpleName
    val pfaPkg = pkg.replace("org.apache", "com.ibm.aardpfark")
    val pfaClass = Class.forName(s"$pfaPkg.PFA$name")

    val ctor = pfaClass.getConstructors()(0)
    ctor.newInstance(transformer).asInstanceOf[ToPFA]
  }
} 
Example 179
Source File: SparkFeaturePFASuiteBase.scala    From aardpfark   with Apache License 2.0 5 votes vote down vote up
package com.ibm.aardpfark.pfa

import com.opendatagroup.hadrian.jvmcompiler.PFAEngine
import org.json4s.DefaultFormats

import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.sql.types.StructType


abstract class SparkPipelinePFASuiteBase[A <: Result](implicit m: Manifest[A])
  extends SparkPredictorPFASuiteBase[A] {
  import com.ibm.aardpfark.spark.ml.SparkSupport._

  protected val schema: StructType

  override protected def transformerToPFA(t: Transformer, pretty: Boolean): String = {
    toPFA(t.asInstanceOf[PipelineModel], schema, pretty)
  }
}

abstract class SparkFeaturePFASuiteBase[A <: Result](implicit m: Manifest[A])
  extends SparkPFASuiteBase {

  implicit val formats = DefaultFormats

  protected var isDebug = false

  import com.ibm.aardpfark.spark.ml.SparkSupport._
  import org.json4s._
  import org.json4s.native.JsonMethods._

  test("PFA transformer produces the same results as Spark transformer") {
    parityTest(sparkTransformer, input, expectedOutput)
  }

  protected def transformerToPFA(t: Transformer, pretty: Boolean): String = {
    toPFA(t, pretty)
  }

  protected def testInputVsExpected(
      engine: PFAEngine[AnyRef, AnyRef],
      input: Array[String],
      expectedOutput: Array[String]) = {
    import ApproxEquality._
    input.zip(expectedOutput).foreach { case (in, out) =>
      val pfaResult = engine.action(engine.jsonInput(in))
      val actual = parse(pfaResult.toString).extract[A]
      val expected = parse(out).extract[A]
      (actual, expected) match {
        case (a: ScalerResult, e: ScalerResult) => assert(a.scaled === e.scaled)
        case (a: Result, e: Result) => assert(a === e)
      }
    }
  }

  def parityTest(
      sparkTransformer: Transformer,
      input: Array[String],
      expectedOutput: Array[String]): Unit = {
    val PFAJson = transformerToPFA(sparkTransformer, pretty = true)
    if (isDebug) {
      println(PFAJson)
    }
    val engine = getPFAEngine(PFAJson)
    testInputVsExpected(engine, input, expectedOutput)
  }
}

case class ScalerResult(scaled: Seq[Double]) extends Result 
Example 180
Source File: Kudu.scala    From kafka-examples   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.streaming.refapp

import org.apache.kudu.spark.kudu._
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SparkSession}


class KuduSink(master: String, database: String, checkpointLocation: String => String) {

    def writeTable(sinkName: String, triggerSeconds: Int = 10) =
      new Sink {
        override def createDataStreamWriter(df: DataFrame): DataStreamWriter[Row] = {
          val fullTableName = s"impala::$database.$name"
          df
            .writeStream
            .format("kudu")
            .option("kudu.master", master)
            .option("kudu.table", fullTableName)
            .option("checkpointLocation", checkpointLocation(name))
            .option("retries", "3")
            .outputMode("update")
        }

        override val name: String = sinkName
      }

} 
Example 181
Source File: SpreadsheetRelation.scala    From spark-google-spreadsheets   with Apache License 2.0 5 votes vote down vote up
package com.github.potix2.spark.google.spreadsheets

import com.github.potix2.spark.google.spreadsheets.SparkSpreadsheetService.SparkSpreadsheetContext
import com.github.potix2.spark.google.spreadsheets.util._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

case class SpreadsheetRelation protected[spark] (
                                                  context:SparkSpreadsheetContext,
                                                  spreadsheetName: String,
                                                  worksheetName: String,
                                                  userSchema: Option[StructType] = None)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with InsertableRelation {

  import com.github.potix2.spark.google.spreadsheets.SparkSpreadsheetService._

  override def schema: StructType = userSchema.getOrElse(inferSchema())

  private lazy val aWorksheet: SparkWorksheet =
    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(aWorksheet) => aWorksheet
      case Left(e) => throw e
    }

  private lazy val rows: Seq[Map[String, String]] = aWorksheet.rows

  private[spreadsheets] def findWorksheet(spreadsheetName: String, worksheetName: String)(implicit ctx: SparkSpreadsheetContext): Either[Throwable, SparkWorksheet] =
    for {
      sheet <- findSpreadsheet(spreadsheetName).toRight(new RuntimeException(s"no such spreadsheet: $spreadsheetName")).right
      worksheet <- sheet.findWorksheet(worksheetName).toRight(new RuntimeException(s"no such worksheet: $worksheetName")).right
    } yield worksheet

  override def buildScan(): RDD[Row] = {
    val aSchema = schema
    sqlContext.sparkContext.makeRDD(rows).mapPartitions { iter =>
      iter.map { m =>
        var index = 0
        val rowArray = new Array[Any](aSchema.fields.length)
        while(index < aSchema.fields.length) {
          val field = aSchema.fields(index)
          rowArray(index) = if (m.contains(field.name)) {
            TypeCast.castTo(m(field.name), field.dataType, field.nullable)
          } else {
            null
          }
          index += 1
        }
        Row.fromSeq(rowArray)
      }
    }
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    if(!overwrite) {
      sys.error("Spreadsheet tables only support INSERT OVERWRITE for now.")
    }

    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(w) =>
        w.updateCells(data.schema, data.collect().toList, Util.toRowData)
      case Left(e) =>
        throw e
    }
  }

  private def inferSchema(): StructType =
    StructType(aWorksheet.headers.toList.map { fieldName =>
      StructField(fieldName, StringType, nullable = true)
    })

} 
Example 182
Source File: Util.scala    From spark-google-spreadsheets   with Apache License 2.0 5 votes vote down vote up
package com.github.potix2.spark.google.spreadsheets

import com.google.api.services.sheets.v4.model.{ExtendedValue, CellData, RowData}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DataTypes, StructType}

import scala.collection.JavaConverters._

object Util {
  def convert(schema: StructType, row: Row): Map[String, Object] =
    schema.iterator.zipWithIndex.map { case (f, i) => f.name -> row(i).asInstanceOf[AnyRef]} toMap

  def toRowData(row: Row): RowData =
      new RowData().setValues(
        row.schema.fields.zipWithIndex.map { case (f, i) =>
          new CellData()
            .setUserEnteredValue(
              f.dataType match {
                case DataTypes.StringType => new ExtendedValue().setStringValue(row.getString(i))
                case DataTypes.LongType => new ExtendedValue().setNumberValue(row.getLong(i).toDouble)
                case DataTypes.IntegerType => new ExtendedValue().setNumberValue(row.getInt(i).toDouble)
                case DataTypes.FloatType => new ExtendedValue().setNumberValue(row.getFloat(i).toDouble)
                case DataTypes.BooleanType => new ExtendedValue().setBoolValue(row.getBoolean(i))
                case DataTypes.DateType => new ExtendedValue().setStringValue(row.getDate(i).toString)
                case DataTypes.ShortType => new ExtendedValue().setNumberValue(row.getShort(i).toDouble)
                case DataTypes.TimestampType => new ExtendedValue().setStringValue(row.getTimestamp(i).toString)
                case DataTypes.DoubleType => new ExtendedValue().setNumberValue(row.getDouble(i))
              }
            )
        }.toList.asJava
      )

} 
Example 183
Source File: DefaultSource.scala    From spark-google-spreadsheets   with Apache License 2.0 5 votes vote down vote up
package com.github.potix2.spark.google.spreadsheets

import java.io.File

import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {
  final val DEFAULT_CREDENTIAL_PATH = "/etc/gdata/credential.p12"

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) = {
    createRelation(sqlContext, parameters, null)
  }

  private[spreadsheets] def pathToSheetNames(parameters: Map[String, String]): (String, String) = {
    val path = parameters.getOrElse("path", sys.error("'path' must be specified for spreadsheets."))
    val elems = path.split('/')
    if (elems.length < 2)
      throw new Exception("'path' must be formed like '<spreadsheet>/<worksheet>'")

    (elems(0), elems(1))
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType) = {
    val (spreadsheetName, worksheetName) = pathToSheetNames(parameters)
    val context = createSpreadsheetContext(parameters)
    createRelation(sqlContext, context, spreadsheetName, worksheetName, schema)
  }


  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    val (spreadsheetName, worksheetName) = pathToSheetNames(parameters)
    implicit val context = createSpreadsheetContext(parameters)
    val spreadsheet = SparkSpreadsheetService.findSpreadsheet(spreadsheetName)
    if (spreadsheet.isEmpty)
      throw new RuntimeException(s"no such spreadsheet: $spreadsheetName")

    spreadsheet.get.addWorksheet(worksheetName, data.schema, data.collect().toList, Util.toRowData)
    createRelation(sqlContext, context, spreadsheetName, worksheetName, data.schema)
  }

  private[spreadsheets] def createSpreadsheetContext(parameters: Map[String, String]) = {
    val serviceAccountIdOption = parameters.get("serviceAccountId")
    val credentialPath = parameters.getOrElse("credentialPath", DEFAULT_CREDENTIAL_PATH)
    SparkSpreadsheetService(serviceAccountIdOption, new File(credentialPath))
  }

  private[spreadsheets] def createRelation(sqlContext: SQLContext,
                                           context: SparkSpreadsheetService.SparkSpreadsheetContext,
                                           spreadsheetName: String,
                                           worksheetName: String,
                                           schema: StructType): SpreadsheetRelation =
    if (schema == null) {
      createRelation(sqlContext, context, spreadsheetName, worksheetName, None)
    }
    else {
      createRelation(sqlContext, context, spreadsheetName, worksheetName, Some(schema))
    }

  private[spreadsheets] def createRelation(sqlContext: SQLContext,
                                           context: SparkSpreadsheetService.SparkSpreadsheetContext,
                                           spreadsheetName: String,
                                           worksheetName: String,
                                           schema: Option[StructType]): SpreadsheetRelation =
    SpreadsheetRelation(context, spreadsheetName, worksheetName, schema)(sqlContext)
} 
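
A hedged read/write sketch against this source; the spreadsheet and worksheet names and the service account id are placeholders, while the option keys and the "<spreadsheet>/<worksheet>" path format come from the parameters the provider reads above:

import org.apache.spark.sql.{DataFrame, SparkSession}

def readSheet(spark: SparkSession): DataFrame =
  spark.read
    .format("com.github.potix2.spark.google.spreadsheets")
    .option("serviceAccountId", "my-account@developer.gserviceaccount.com")
    .option("credentialPath", "/etc/gdata/credential.p12")
    .load("MySpreadsheet/Sheet1")

def writeSheet(spark: SparkSession, df: DataFrame): Unit =
  df.write
    .format("com.github.potix2.spark.google.spreadsheets")
    .option("serviceAccountId", "my-account@developer.gserviceaccount.com")
    .option("credentialPath", "/etc/gdata/credential.p12")
    .save("MySpreadsheet/NewSheet") // adds the worksheet and writes df's rows
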
Example 184
Source File: SparkSpreadsheetServiceWriteSuite.scala    From spark-google-spreadsheets   with Apache License 2.0 5 votes vote down vote up
package com.github.potix2.spark.google.spreadsheets

import com.github.potix2.spark.google.spreadsheets.SparkSpreadsheetService.SparkSpreadsheet
import com.google.api.services.sheets.v4.model.{ExtendedValue, CellData, RowData}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.scalatest.{BeforeAndAfter, FlatSpec}

import scala.collection.JavaConverters._

class SparkSpreadsheetServiceWriteSuite extends FlatSpec with BeforeAndAfter {
  private val serviceAccountId = "53797494708-ds5v22b6cbpchrv2qih1vg8kru098k9i@developer.gserviceaccount.com"
  private val testCredentialPath = "src/test/resources/spark-google-spreadsheets-test-eb7b191d1e1d.p12"
  private val TEST_SPREADSHEET_NAME = "WriteSuite"
  private val TEST_SPREADSHEET_ID = "163Ja2OWUephWjIa-jpwTlvGcg8EJwCFCfxrF7aI117s"

  private val context: SparkSpreadsheetService.SparkSpreadsheetContext =
    SparkSpreadsheetService.SparkSpreadsheetContext(Some(serviceAccountId), new java.io.File(testCredentialPath))

  var spreadsheet: SparkSpreadsheet = null
  var worksheetName: String = ""

  def definedSchema: StructType = {
    new StructType()
      .add(new StructField("col_1", DataTypes.StringType))
      .add(new StructField("col_2", DataTypes.LongType))
      .add(new StructField("col_3", DataTypes.StringType))
  }

  case class Elem(col_1: String, col_2: Long, col_3: String)

  def extractor(e: Elem): RowData =
    new RowData().setValues(
      List(
        new CellData().setUserEnteredValue(
          new ExtendedValue().setStringValue(e.col_1)
        ),
        new CellData().setUserEnteredValue(
          new ExtendedValue().setNumberValue(e.col_2.toDouble)
        ),
        new CellData().setUserEnteredValue(
          new ExtendedValue().setStringValue(e.col_3)
        )
      ).asJava
    )

  before {
    spreadsheet = context.findSpreadsheet(TEST_SPREADSHEET_ID)
    worksheetName = scala.util.Random.alphanumeric.take(16).mkString
    val data = List(
      Elem("a", 1L, "x"),
      Elem("b", 2L, "y"),
      Elem("c", 3L, "z")
    )

    spreadsheet.addWorksheet(worksheetName, definedSchema, data, extractor)
  }

  after {
    spreadsheet.deleteWorksheet(worksheetName)
  }

  behavior of "A Spreadsheet"
  it should "find the new worksheet" in {
    val newWorksheet = spreadsheet.findWorksheet(worksheetName)
    assert(newWorksheet.isDefined)
    assert(newWorksheet.get.name == worksheetName)
    assert(newWorksheet.get.headers == Seq("col_1", "col_2", "col_3"))

    val rows = newWorksheet.get.rows
    assert(rows.head == Map("col_1" -> "a", "col_2" -> "1", "col_3" -> "x"))
  }

  behavior of "SparkWorksheet#updateCells"
  it should "update values in a worksheet" in {
    val newWorksheet = spreadsheet.findWorksheet(worksheetName)
    assert(newWorksheet.isDefined)

    val newData = List(
      Elem("f", 5L, "yy"),
      Elem("e", 4L, "xx"),
      Elem("c", 3L, "z"),
      Elem("b", 2L, "y"),
      Elem("a", 1L, "x")
    )

    newWorksheet.get.updateCells(definedSchema, newData, extractor)

    val rows = newWorksheet.get.rows
    assert(rows.head == Map("col_1" -> "f", "col_2" -> "5", "col_3" -> "yy"))
    assert(rows.last == Map("col_1" -> "a", "col_2" -> "1", "col_3" -> "x"))
  }
} 
Example 185
Source File: DataTypeUtil.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

object DataTypeUtil {
  def sameType(left: DataType, right: DataType): Boolean =
    if (SQLConf.get.caseSensitiveAnalysis) {
      equalsIgnoreNullability(left, right)
    } else {
      equalsIgnoreCaseAndNullability(left, right)
    }

  private def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
    (left, right) match {
      case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
        equalsIgnoreNullability(leftElementType, rightElementType)
      case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
        equalsIgnoreNullability(leftKeyType, rightKeyType) &&
          equalsIgnoreNullability(leftValueType, rightValueType)
      case (StructType(leftFields), StructType(rightFields)) =>
        leftFields.length == rightFields.length &&
          leftFields.zip(rightFields).forall { case (l, r) =>
            l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType)
          }
      case (l, r) => l == r
    }
  }

  private def equalsIgnoreCaseAndNullability(from: DataType, to: DataType): Boolean = {
    (from, to) match {
      case (ArrayType(fromElement, _), ArrayType(toElement, _)) =>
        equalsIgnoreCaseAndNullability(fromElement, toElement)

      case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
        equalsIgnoreCaseAndNullability(fromKey, toKey) &&
          equalsIgnoreCaseAndNullability(fromValue, toValue)

      case (StructType(fromFields), StructType(toFields)) =>
        fromFields.length == toFields.length &&
          fromFields.zip(toFields).forall { case (l, r) =>
            l.name.equalsIgnoreCase(r.name) &&
              equalsIgnoreCaseAndNullability(l.dataType, r.dataType)
          }

      case (fromDataType, toDataType) => fromDataType == toDataType
    }
  }
} 
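
A small sketch of sameType: the two schemas below differ only in field-name case and nullability, so the result tracks spark.sql.caseSensitive as described above.

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.DataTypeUtil

object SameTypeSketch extends App {
  val a = StructType(Seq(
    StructField("ID", IntegerType, nullable = true),
    StructField("name", StringType, nullable = false)))

  val b = StructType(Seq(
    StructField("id", IntegerType, nullable = false),
    StructField("name", StringType, nullable = true)))

  // true with the default case-insensitive analysis (only case and nullability differ);
  // false once spark.sql.caseSensitive is set to true.
  println(DataTypeUtil.sameType(a, b))
}
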
Example 186
Source File: SQLTransformer.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.ml.feature

import com.tencent.angel.sona.ml.Transformer
import com.tencent.angel.sona.ml.param.{Param, ParamMap}
import com.tencent.angel.sona.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.types.StructType

/**
  * Implements the transformations which are defined by SQL statement.
  * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__ ...'
  * where '__THIS__' represents the underlying table of the input dataset.
  * The select clause specifies the fields, constants, and expressions to display in
  * the output, it can be any select clause that Spark SQL supports. Users can also
  * use Spark SQL built-in function and UDFs to operate on these selected columns.
  * For example, [[SQLTransformer]] supports statements like:
  * {{{
  *  SELECT a, a + b AS a_b FROM __THIS__
  *  SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5
  *  SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b
  * }}}
  */
class SQLTransformer(override val uid: String) extends Transformer
  with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("sql"))

  /**
    * SQL statement parameter. The statement is provided in string form.
    *
    * @group param
    */
  final val statement: Param[String] = new Param[String](this, "statement", "SQL statement")

  
  def setStatement(value: String): this.type = set(statement, value)

  
  def getStatement: String = $(statement)

  private val tableIdentifier: String = "__THIS__"

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val tableName = Identifiable.randomUID(uid)
    dataset.createOrReplaceTempView(tableName)
    val realStatement = $(statement).replace(tableIdentifier, tableName)
    val result = dataset.sparkSession.sql(realStatement)
    // Call SessionCatalog.dropTempView to avoid unpersisting the possibly cached dataset.
    dataset.sparkSession.catalog.dropTempView(tableName)
    // Compatible.sessionstate.catalog.dropTempView(tableName)
    result
  }

  override def transformSchema(schema: StructType): StructType = {
    val spark = SparkSession.builder().getOrCreate()
    val dummyRDD = spark.sparkContext.parallelize(Seq(Row.empty))
    val dummyDF = spark.createDataFrame(dummyRDD, schema)
    val tableName = Identifiable.randomUID(uid)
    val realStatement = $(statement).replace(tableIdentifier, tableName)
    dummyDF.createOrReplaceTempView(tableName)
    val outputSchema = spark.sql(realStatement).schema
    spark.catalog.dropTempView(tableName)
    outputSchema
  }

  override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
}


object SQLTransformer extends DefaultParamsReadable[SQLTransformer] {
  override def load(path: String): SQLTransformer = super.load(path)
} 
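
A hedged usage sketch mirroring the statements listed in the scaladoc above; the input columns a and b are illustrative:

import com.tencent.angel.sona.ml.feature.SQLTransformer
import org.apache.spark.sql.SparkSession

object SQLTransformerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sqltrans-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "a", "b")

    val sqlTrans = new SQLTransformer()
      .setStatement("SELECT *, (a + b) AS a_b FROM __THIS__")

    sqlTrans.transform(df).show() // id, a, b plus the derived a_b column
    spark.stop()
  }
}
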
Example 187
Source File: Correlation.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.ml.stat

import org.apache.spark.linalg.{SQLDataTypes, Vector}

import scala.collection.JavaConverters._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.types.{StructField, StructType}

/**
 * API for correlation functions in MLlib, compatible with DataFrames and Datasets.
 *
 * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset#stat]]
 * to spark.ml's Vector types.
 */
object Correlation {

  /**
   * :: Experimental ::
   * Compute the correlation matrix for the input Dataset of Vectors using the specified method.
   * Methods currently supported: `pearson` (default), `spearman`.
   *
   * @param dataset A dataset or a dataframe
   * @param column The name of the column of vectors for which the correlation coefficient needs
   *               to be computed. This must be a column of the dataset, and it must contain
   *               Vector objects.
   * @param method String specifying the method to use for computing correlation.
   *               Supported: `pearson` (default), `spearman`
   * @return A dataframe that contains the correlation matrix of the column of vectors. This
   *         dataframe contains a single row and a single column of name
   *         '$METHODNAME($COLUMN)'.
   * @throws IllegalArgumentException if the column is not a valid column in the dataset, or if
   *                                  the content of this column is not of type Vector.
   *
   *  Here is how to access the correlation coefficient:
   *  {{{
   *    val data: Dataset[Vector] = ...
   *    val Row(coeff: Matrix) = Correlation.corr(data, "value").head
   *    // coeff now contains the Pearson correlation matrix.
   *  }}}
   *
   * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column
   * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
   * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"`
   * to avoid recomputing the common lineage.
   */
  def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
    val rdd = dataset.select(column).rdd.map {
      case Row(v: Vector) => v
    }
    val oldM = Statistics.corr(rdd, method)
    val name = s"$method($column)"
    val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
    dataset.sparkSession.createDataFrame(Seq(Row(oldM)).asJava, schema)
  }

  /**
   * Compute the Pearson correlation matrix for the input Dataset of Vectors.
   */
  def corr(dataset: Dataset[_], column: String): DataFrame = {
    corr(dataset, column, "pearson")
  }
} 
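A usage sketch that fleshes out the doc snippet above. It assumes an active SparkSession `spark` and that the sona vector UDTs have been registered (see SPKSQLUtils.registerUDT further below), so a Seq of Tuple1-wrapped vectors can be turned into a DataFrame:

import org.apache.spark.linalg.{Matrix, Vectors}
import org.apache.spark.sql.Row

val data = Seq(
  Vectors.dense(1.0, 2.0, 4.0),
  Vectors.dense(2.0, 3.0, 5.0),
  Vectors.dense(3.0, 5.0, 6.0)
).map(Tuple1.apply)
val df = spark.createDataFrame(data).toDF("features")

val Row(pearson: Matrix) = Correlation.corr(df, "features").head
val Row(spearman: Matrix) = Correlation.corr(df, "features", "spearman").head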
Example 188
Source File: Predictor.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.ml.common
import com.tencent.angel.mlcore.conf.{MLCoreConf, SharedConf}
import com.tencent.angel.ml.math2.utils.{DataBlock, LabeledData}
import org.apache.spark.broadcast.Broadcast
import com.tencent.angel.sona.ml.common.MathImplicits._
import com.tencent.angel.sona.core.{AngelGraphModel, ExecutorContext}
import com.tencent.angel.sona.data.LocalMemoryDataBlock
import org.apache.spark.linalg
import org.apache.spark.linalg.Vectors
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.sql.{Row, SPKSQLUtils}

import scala.collection.mutable.ListBuffer

class Predictor(bcValue: Broadcast[ExecutorContext],
                featIdx: Int, predictionCol: String, probabilityCol: String,
                bcConf: Broadcast[SharedConf]) extends Serializable {

  @transient private lazy val executorContext: ExecutorContext = {
    bcValue.value
  }

  @transient private lazy implicit val dim: Long = {
    executorContext.conf.getLong(MLCoreConf.ML_FEATURE_INDEX_RANGE)
  }

  @transient private lazy val appendedSchema: StructType = if (probabilityCol.nonEmpty) {
    new StructType(Array[StructField](StructField(probabilityCol, DoubleType),
      StructField(predictionCol, DoubleType)))
  } else {
    new StructType(Array[StructField](StructField(predictionCol, DoubleType)))
  }

  def predictRDD(data: Iterator[Row]): Iterator[Row] = {
    val localModel = executorContext.borrowModel(bcConf.value)
    val batchSize = 1024
    val storage = new LocalMemoryDataBlock(batchSize, batchSize * 1024 * 1024)

    var count = 0
    val cachedRows: Array[Row] = new Array[Row](batchSize)
    val result: ListBuffer[Row] = ListBuffer[Row]()
    data.foreach {
      case row if count != 0 && count % batchSize == 0 =>
        predictInternal(localModel, storage, cachedRows, result)

        storage.clean()
        storage.put(new LabeledData(row.get(featIdx).asInstanceOf[linalg.Vector], 0.0))
        cachedRows(count % batchSize) = row
        count += 1
      case row =>
        storage.put(new LabeledData(row.get(featIdx).asInstanceOf[linalg.Vector], 0.0))
        cachedRows(count % batchSize) = row
        count += 1
    }

    predictInternal(localModel, storage, cachedRows, result)

    executorContext.returnModel(localModel)

    result.toIterator
  }

  private def predictInternal(model: AngelGraphModel,
                              storage: DataBlock[LabeledData],
                              cachedRows: Array[Row],
                              result: ListBuffer[Row]): Unit = {
    val predicted = model.predict(storage)

    if (appendedSchema.length == 1) {
      predicted.zipWithIndex.foreach {
        case (res, idx) =>
          result.append(SPKSQLUtils.append(cachedRows(idx), appendedSchema, res.pred))
      }
    } else {
      predicted.zipWithIndex.foreach {
        case (res, idx) =>
          result.append(SPKSQLUtils.append(cachedRows(idx), appendedSchema, res.proba, res.predLabel))
      }
    }

  }

  def predictRaw(features: linalg.Vector): linalg.Vector = {
    val localModel = executorContext.borrowModel(bcConf.value)

    val res = localModel.predict(new LabeledData(features, 0.0))

    executorContext.returnModel(localModel)
    Vectors.dense(res.pred, -res.pred)
  }
} 
Example 189
Source File: KCore.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.graph.kcore
import com.tencent.angel.sona.context.PSContext
import org.apache.spark.SparkContext
import com.tencent.angel.sona.graph.params._
import com.tencent.angel.sona.ml.Transformer
import com.tencent.angel.sona.ml.param.ParamMap
import com.tencent.angel.sona.ml.util.Identifiable
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.storage.StorageLevel

class KCore(override val uid: String) extends Transformer
  with HasSrcNodeIdCol with HasDstNodeIdCol with HasOutputNodeIdCol with HasOutputCoreIdCol
  with HasStorageLevel with HasPartitionNum with HasPSPartitionNum with HasUseBalancePartition {

  def this() = this(Identifiable.randomUID("KCore"))

  override def transform(dataset: Dataset[_]): DataFrame = {
    val edges = dataset.select($(srcNodeIdCol), $(dstNodeIdCol)).rdd
      .map(row => (row.getLong(0), row.getLong(1)))
      .filter(e => e._1 != e._2)

    edges.persist(StorageLevel.DISK_ONLY)

    val maxId = edges.map(e => math.max(e._1, e._2)).max() + 1
    val minId = edges.map(e => math.min(e._1, e._2)).min()
    val nodes = edges.flatMap(e => Iterator(e._1, e._2))
    val numEdges = edges.count()

    println(s"minId=$minId maxId=$maxId numEdges=$numEdges level=${$(storageLevel)}")

    // Start PS and init the model
    println("start to run ps")
    PSContext.getOrCreate(SparkContext.getOrCreate())

    val model = KCorePSModel.fromMinMax(minId, maxId, nodes, $(psPartitionNum), $(useBalancePartition))
    var graph = edges.flatMap(e => Iterator((e._1, e._2), (e._2, e._1)))
      .groupByKey($(partitionNum))
      .mapPartitionsWithIndex((index, edgeIter) =>
        Iterator(KCoreGraphPartition.apply(index, edgeIter)))

    graph.persist($(storageLevel))
    graph.foreachPartition(_ => ())
    graph.foreach(_.initMsgs(model))

    var curIteration = 0
    var numMsgs = model.numMsgs()
    var prev = graph
    println(s"numMsgs=$numMsgs")

    do {
      curIteration += 1
      graph = prev.map(_.process(model, numMsgs, curIteration == 1))
      graph.persist($(storageLevel))
      graph.count()
      prev.unpersist(true)
      prev = graph
      model.resetMsgs()
      numMsgs = model.numMsgs()
      println(s"curIteration=$curIteration numMsgs=$numMsgs")
    } while (numMsgs > 0)

    val retRDD = graph.map(_.save()).flatMap { case (nodes, cores) => nodes.zip(cores) }
      .map(r => Row.fromSeq(Seq[Any](r._1, r._2)))

    dataset.sparkSession.createDataFrame(retRDD, transformSchema(dataset.schema))
  }

  override def transformSchema(schema: StructType): StructType = {
    StructType(Seq(
      StructField(s"${$(outputNodeIdCol)}", LongType, nullable = false),
      StructField(s"${$(outputCoreIdCol)}", IntegerType, nullable = false)
    ))
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

} 
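A hedged usage sketch; the set* builders below are assumed to be supplied by the Has* param traits mixed into KCore (they are not shown in this listing), and `edgeDF` stands for any DataFrame with long-typed source and destination node columns:

val kcore = new KCore()
  .setSrcNodeIdCol("src")
  .setDstNodeIdCol("dst")
  .setOutputNodeIdCol("node")
  .setOutputCoreIdCol("core")
  .setPartitionNum(4)
  .setPSPartitionNum(2)
  .setUseBalancePartition(false)

// Returns a DataFrame with schema (node: long, core: int), per transformSchema above.
val cores = kcore.transform(edgeDF)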
Example 190
Source File: SQLTransformerSuite.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.ml.feature

import com.tencent.angel.sona.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.storage.StorageLevel

class SQLTransformerSuite extends MLTest with DefaultReadWriteTest {

  import testImplicits._

  test("transform numeric data") {
    val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
    val sqlTrans = new SQLTransformer().setStatement(
      "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
    val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))
      .toDF("id", "v1", "v2", "v3", "v4")
    val resultSchema = sqlTrans.transformSchema(original.schema)
    testTransformerByGlobalCheckFunc[(Int, Double, Double)](
      original,
      sqlTrans,
      "id",
      "v1",
      "v2",
      "v3",
      "v4") { rows =>
      assert(rows.head.schema.toString == resultSchema.toString)
      assert(resultSchema == expected.schema)
      assert(rows == expected.collect().toSeq)
      assert(original.sparkSession.catalog.listTables().count() == 0)
    }
  }

  test("read/write") {
    val t = new SQLTransformer()
      .setStatement("select * from __THIS__")
    testDefaultReadWrite(t)
  }

  test("transformSchema") {
    val df = spark.range(10)
    val outputSchema = new SQLTransformer()
      .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
      .transformSchema(df.schema)
    val expected = StructType(Seq(StructField("id1", LongType, nullable = false)))
    assert(outputSchema === expected)
  }

  ignore("SPARK-22538: SQLTransformer should not unpersist given dataset") {
    val df = spark.range(10).toDF()
    df.cache()
    df.count()
    assert(df.storageLevel != StorageLevel.NONE)
    val sqlTrans = new SQLTransformer()
      .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
    testTransformerByGlobalCheckFunc[Long](df, sqlTrans, "id1") { _ => }
    assert(df.storageLevel != StorageLevel.NONE)
  }
} 
Example 191
Source File: SPKSQLUtils.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema}
import org.apache.spark.sql.types.{StructType, UDTRegistration}

object SPKSQLUtils {
  def append(row: Row, fields: StructType, values: Any*): Row = {
    row match {
      case r: GenericRowWithSchema =>
        val newValues = new Array[Any](r.length + values.length)
        val rLength: Int = r.length
        (0 until rLength).foreach(idx => newValues(idx) = r(idx))
        values.zipWithIndex.foreach { case (value, idx) =>
          newValues(idx + rLength) = value
        }

        val newSchema = if (r.schema != null) {
          // StructType.add returns a new StructType rather than mutating in place,
          // so build the extended schema by concatenating the field arrays.
          StructType(r.schema.fields ++ fields.fields)
        } else {
          null.asInstanceOf[StructType]
        }
        new GenericRowWithSchema(newValues, newSchema)
      case r: GenericRow =>
        val newValues = new Array[Any](r.length + values.length)
        val rLength: Int = r.length
        (0 until rLength).foreach(idx => newValues(idx) = r(idx))
        values.zipWithIndex.foreach { case (value, idx) =>
          newValues(idx + rLength) = value
        }

        new GenericRow(newValues)
      case _ =>
        throw new IllegalArgumentException(s"Cannot append values to a row of type ${row.getClass.getName}")
    }
  }

  def registerUDT(): Unit = synchronized{
    UDTRegistration.register("org.apache.spark.linalg.Vector", "org.apache.spark.linalg.VectorUDT")
    UDTRegistration.register("org.apache.spark.linalg.DenseVector", "org.apache.spark.linalg.VectorUDT")
    UDTRegistration.register("org.apache.spark.linalg.SparseVector", "org.apache.spark.linalg.VectorUDT")
    UDTRegistration.register("org.apache.spark.linalg.Matrix", "org.apache.spark.linalg.MatrixUDT")
    UDTRegistration.register("org.apache.spark.linalg.DenseMatrix", "org.apache.spark.linalg.MatrixUDT")
    UDTRegistration.register("org.apache.spark.linalg.SparseMatrix", "org.apache.spark.linalg.MatrixUDT")
  }
} 
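A minimal sketch of append on a schema-carrying row; the schema and values here are arbitrary examples:

import org.apache.spark.sql.{Row, SPKSQLUtils}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

val baseSchema = StructType(Seq(StructField("id", IntegerType)))
val base: Row = new GenericRowWithSchema(Array[Any](1), baseSchema)
val extra = StructType(Seq(StructField("prediction", DoubleType)))

val appended = SPKSQLUtils.append(base, extra, 0.5)
// appended.toSeq == Seq(1, 0.5); its schema now carries the fields id and prediction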
Example 192
Source File: VCFFileWriter.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.vcf

import java.io.OutputStream

import htsjdk.samtools.ValidationStringency
import htsjdk.variant.vcf._
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types.StructType

import io.projectglow.common.GlowLogging

class VCFFileWriter(
    headerLineSet: Set[VCFHeaderLine],
    sampleIdInfo: SampleIdInfo,
    stringency: ValidationStringency,
    schema: StructType,
    conf: Configuration,
    stream: OutputStream,
    writeHeader: Boolean)
    extends OutputWriter
    with GlowLogging {

  private val converter =
    new InternalRowToVariantContextConverter(schema, headerLineSet, stringency)
  converter.validate()
  private val writer: VCFStreamWriter =
    new VCFStreamWriter(stream, headerLineSet, sampleIdInfo, writeHeader)

  override def write(row: InternalRow): Unit = {
    converter.convert(row).foreach(writer.write)
  }

  override def close(): Unit = {
    writer.close()
  }
} 
Example 193
Source File: AnnotationUtils.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.vcf

import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, StringType, StructField, StructType}

// Unified VCF annotation representation, used by SnpEff and VEP
object AnnotationUtils {

  // Delimiter between annotation fields
  val annotationDelimiter = "|"
  val annotationDelimiterRegex = "\\|"

  // Delimiter for fraction-style struct subfields (e.g. rank/total, pos/length)
  val structDelimiter = "/"
  val structDelimiterRegex = "\\/"

  // Delimiter for array subfields
  val arrayDelimiter = "&"

  // Struct subfield schemas
  private val rankTotalStruct = StructType(
    Seq(StructField("rank", IntegerType), StructField("total", IntegerType)))
  private val posLengthStruct = StructType(
    Seq(StructField("pos", IntegerType), StructField("length", IntegerType)))
  private val referenceVariantStruct = StructType(
    Seq(StructField("reference", StringType), StructField("variant", StringType)))

  // Special schemas for SnpEff subfields
  private val snpEffFieldsToSchema: Map[String, DataType] = Map(
    "Annotation" -> ArrayType(StringType),
    "Rank" -> rankTotalStruct,
    "cDNA_pos/cDNA_length" -> posLengthStruct,
    "CDS_pos/CDS_length" -> posLengthStruct,
    "AA_pos/AA_length" -> posLengthStruct,
    "Distance" -> IntegerType
  )

  // Special schemas for VEP subfields
  private val vepFieldsToSchema: Map[String, DataType] = Map(
    "Consequence" -> ArrayType(StringType),
    "EXON" -> rankTotalStruct,
    "INTRON" -> rankTotalStruct,
    "cDNA_position" -> IntegerType,
    "CDS_position" -> IntegerType,
    "Protein_position" -> IntegerType,
    "Amino_acids" -> referenceVariantStruct,
    "Codons" -> referenceVariantStruct,
    "Existing_variation" -> ArrayType(StringType),
    "DISTANCE" -> IntegerType,
    "STRAND" -> IntegerType,
    "FLAGS" -> ArrayType(StringType)
  )

  // Special schemas for LOFTEE (as VEP plugin) subfields
  private val lofteeFieldsToSchema: Map[String, DataType] = Map(
    "LoF_filter" -> ArrayType(StringType),
    "LoF_flags" -> ArrayType(StringType),
    "LoF_info" -> ArrayType(StringType)
  )

  // Default string schema for annotation subfield
  val allFieldsToSchema: Map[String, DataType] =
    (snpEffFieldsToSchema ++ vepFieldsToSchema ++ lofteeFieldsToSchema).withDefaultValue(StringType)
} 
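Because allFieldsToSchema carries a StringType default, lookups never fail; only the fields listed above get special types. For example (the last field name is an arbitrary illustration):

AnnotationUtils.allFieldsToSchema("Consequence")  // ArrayType(StringType)
AnnotationUtils.allFieldsToSchema("EXON")         // rank/total struct
AnnotationUtils.allFieldsToSchema("Gene_Name")    // StringType (the default)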
Example 194
Source File: VCFWriterUtils.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.vcf

import htsjdk.variant.variantcontext.{VariantContext, VariantContextBuilder}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{ArrayType, StructType}

import io.projectglow.common.GlowLogging

object VCFWriterUtils extends GlowLogging {

  def throwMixedSamplesFailure(): Unit = {
    throw new IllegalArgumentException("Cannot mix missing and non-missing sample IDs.")
  }

  def throwSampleInferenceFailure(): Unit = {
    throw new IllegalArgumentException(
      "Cannot infer sample ids because they are not the same in every row.")
  }

  /**
   * Infers the sample IDs to write from the DataFrame's genotypes column: if a sampleId
   * subfield is present and non-missing it is used directly; otherwise the writer is told
   * to generate placeholder IDs (InferSampleIds).
   */
  def inferSampleIdsIfPresent(data: DataFrame): SampleIdInfo = {
    val genotypeSchemaOpt = data
      .schema
      .find(_.name == "genotypes")
      .map(_.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType])
    if (genotypeSchemaOpt.isEmpty) {
      logger.info("No genotypes column; no sample IDs will be inferred.")
      return SampleIds(Seq.empty)
    }
    val genotypeSchema = genotypeSchemaOpt.get

    import data.sparkSession.implicits._
    val hasSampleIdsColumn = genotypeSchema.exists(_.name == "sampleId")

    if (hasSampleIdsColumn) {
      val distinctSampleIds = data
        .selectExpr("explode(genotypes.sampleId)")
        .distinct()
        .as[String]
        .collect
      val numPresentSampleIds = distinctSampleIds.count(!sampleIsMissing(_))

      if (numPresentSampleIds > 0) {
        if (numPresentSampleIds < distinctSampleIds.length) {
          throwMixedSamplesFailure()
        }
        return SampleIds(distinctSampleIds)
      }
    }

    val numGenotypesPerRow = data
      .selectExpr("size(genotypes)")
      .distinct()
      .as[Int]
      .collect

    if (numGenotypesPerRow.length > 1) {
      throw new IllegalArgumentException(
        "Rows contain varying number of missing samples; cannot infer sample IDs.")
    }

    logger.warn("Detected missing sample IDs, inferring sample IDs.")
    InferSampleIds
  }

  def sampleIsMissing(s: String): Boolean = {
    s == null || s.isEmpty
  }

  def convertVcAttributesToStrings(vc: VariantContext): VariantContextBuilder = {
    val vcBuilder = new VariantContextBuilder(vc)
    val iterator = vc.getAttributes.entrySet().iterator()
    while (iterator.hasNext) {
      // parse to string, then write, as the VCF encoder messes up double precisions
      val entry = iterator.next()
      vcBuilder.attribute(
        entry.getKey,
        VariantContextToInternalRowConverter.parseObjectAsString(entry.getValue))
    }
    vcBuilder
  }
}

case class SampleIds(unsortedSampleIds: Seq[String]) extends SampleIdInfo {
  val sortedSampleIds: Seq[String] = unsortedSampleIds.sorted
}
case object InferSampleIds extends SampleIdInfo {
  def fromNumberMissing(numMissingSamples: Int): Seq[String] = {
    (1 to numMissingSamples).map { idx =>
      "sample_" + idx
    }
  }
}

sealed trait SampleIdInfo 
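When InferSampleIds is returned, downstream writers can generate deterministic placeholder IDs from the number of missing samples in a row:

InferSampleIds.fromNumberMissing(3)  // Seq("sample_1", "sample_2", "sample_3")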
Example 195
Source File: CSVOutputFormatter.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.transformers.pipe

import java.io.InputStream

import scala.collection.JavaConverters._

import com.univocity.parsers.csv.CsvParser
import org.apache.commons.io.IOUtils
import org.apache.spark.sql.execution.datasources.csv.{CSVDataSourceUtils, CSVUtils, UnivocityParserUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import io.projectglow.SparkShim.{CSVOptions, UnivocityParser}

class CSVOutputFormatter(parsedOptions: CSVOptions) extends OutputFormatter {

  private def getSchema(record: Array[String]): StructType = {
    val header =
      CSVDataSourceUtils.makeSafeHeader(
        record,
        SQLConf.get.caseSensitiveAnalysis,
        parsedOptions
      )
    val fields = header.map { fieldName =>
      StructField(fieldName, StringType, nullable = true)
    }
    StructType(fields)
  }

  /**
   * Parses CSV text from the stream. The returned iterator emits the inferred schema
   * (all string-typed fields) first, followed by the parsed rows, dropping the header
   * row when the header option is enabled.
   */
  override def makeIterator(stream: InputStream): Iterator[Any] = {
    val lines = IOUtils.lineIterator(stream, "UTF-8").asScala
    val filteredLines = CSVUtils.filterCommentAndEmpty(lines, parsedOptions)

    if (filteredLines.isEmpty) {
      return Iterator.empty
    }

    val firstLine = filteredLines.next
    val csvParser = new CsvParser(parsedOptions.asParserSettings)
    val firstRecord = csvParser.parseLine(firstLine)
    val schema = getSchema(firstRecord)
    val univocityParser = new UnivocityParser(schema, schema, parsedOptions)

    val parsedIter =
      UnivocityParserUtils.parseIterator(
        Iterator(firstLine) ++ filteredLines,
        univocityParser,
        schema
      )

    val parsedIterWithoutHeader = if (parsedOptions.headerFlag) {
      parsedIter.drop(1)
    } else {
      parsedIter
    }

    Iterator(schema) ++ parsedIterWithoutHeader.map(_.copy)
  }
}

class CSVOutputFormatterFactory extends OutputFormatterFactory {
  override def name: String = "csv"

  override def makeOutputFormatter(
      options: Map[String, String]
  ): OutputFormatter = {
    val parsedOptions =
      new CSVOptions(
        options,
        SQLConf.get.csvColumnPruning,
        SQLConf.get.sessionLocalTimeZone
      )
    new CSVOutputFormatter(parsedOptions)
  }
} 
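A quick sketch of the formatter contract (schema first, then rows). It goes through the factory shown above and assumes the default SQLConf is sufficient for these CSV options:

import java.io.ByteArrayInputStream
import org.apache.spark.sql.types.StructType

val formatter = new CSVOutputFormatterFactory().makeOutputFormatter(Map("header" -> "true"))
val stream = new ByteArrayInputStream("a,b\n1,2\n3,4\n".getBytes("UTF-8"))

val iter = formatter.makeIterator(stream)
val schema = iter.next().asInstanceOf[StructType]  // two string fields named a and b
val rows = iter.toList                             // two InternalRows: (1, 2) and (3, 4)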
Example 196
Source File: UTF8TextOutputFormatter.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.transformers.pipe

import java.io.InputStream

import scala.collection.JavaConverters._

import org.apache.commons.io.IOUtils
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

/** Emits each line of the piped program's output as a row with a single string column. */
class UTF8TextOutputFormatter() extends OutputFormatter {

  override def makeIterator(stream: InputStream): Iterator[Any] = {
    val schema = StructType(Seq(StructField("text", StringType)))
    val iter = IOUtils.lineIterator(stream, "UTF-8").asScala.map { s =>
      new GenericInternalRow(Array(UTF8String.fromString(s)): Array[Any])
    }
    Iterator(schema) ++ iter
  }
}

class UTF8TextOutputFormatterFactory extends OutputFormatterFactory {
  override def name: String = "text"

  override def makeOutputFormatter(options: Map[String, String]): OutputFormatter = {
    new UTF8TextOutputFormatter
  }
} 
Example 197
Source File: CSVInputFormatter.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.transformers.pipe

import java.io.{OutputStream, PrintWriter}

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.csv.SGUnivocityGenerator
import org.apache.spark.sql.types.StructType

import io.projectglow.SparkShim.CSVOptions

class CSVInputFormatter(schema: StructType, parsedOptions: CSVOptions) extends InputFormatter {

  private var writer: PrintWriter = _
  private var univocityGenerator: SGUnivocityGenerator = _

  override def init(stream: OutputStream): Unit = {
    writer = new PrintWriter(stream)
    univocityGenerator = new SGUnivocityGenerator(schema, writer, parsedOptions)
    if (parsedOptions.headerFlag) {
      univocityGenerator.writeHeaders()
    }
  }

  override def write(record: InternalRow): Unit = {
    univocityGenerator.write(record)
  }

  override def close(): Unit = {
    writer.close()
    univocityGenerator.close()
  }
}

class CSVInputFormatterFactory extends InputFormatterFactory {
  override def name: String = "csv"

  override def makeInputFormatter(
      df: DataFrame,
      options: Map[String, String]
  ): InputFormatter = {
    val sqlConf = df.sparkSession.sessionState.conf
    val parsedOptions =
      new CSVOptions(
        options,
        sqlConf.csvColumnPruning,
        sqlConf.sessionLocalTimeZone
      )
    new CSVInputFormatter(df.schema, parsedOptions)
  }
} 
Example 198
Source File: PlinkRowToInternalRowConverter.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.plink

import org.apache.spark.sql.SQLUtils.structFieldsEqualExceptNullability
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.{ArrayType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.common.{GlowLogging, VariantSchemas}
import io.projectglow.sql.util.RowConverter

/**
 * Fills the genotypes field of a PLINK variant row by decoding the packed 2-bit genotype
 * block for each sample, following the provided variant schema.
 */
class PlinkRowToInternalRowConverter(schema: StructType) extends GlowLogging {

  private val homAlt = new GenericArrayData(Array(1, 1))
  private val missing = new GenericArrayData(Array(-1, -1))
  private val het = new GenericArrayData(Array(0, 1))
  private val homRef = new GenericArrayData(Array(0, 0))

  private def twoBitsToCalls(twoBits: Int): GenericArrayData = {
    twoBits match {
      case 0 => homAlt // Homozygous for first (alternate) allele
      case 1 => missing // Missing genotype
      case 2 => het // Heterozygous
      case 3 => homRef // Homozygous for second (reference) allele
    }
  }

  private val converter = {
    val fns = schema.map { field =>
      val fn: RowConverter.Updater[(Array[UTF8String], Array[Byte])] = field match {
        case f if f.name == VariantSchemas.genotypesFieldName =>
          val gSchema = f.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]
          val converter = makeGenotypeConverter(gSchema)
          (samplesAndBlock, r, i) => {
            val genotypes = new Array[Any](samplesAndBlock._1.length)
            var sampleIdx = 0
            while (sampleIdx < genotypes.length) {
              val sample = samplesAndBlock._1(sampleIdx)
              // Get the relevant 2 bits for the sample from the block
              // The i-th sample's call bits are the (i%4)-th pair within the (i/4)-th block
              val twoBits = samplesAndBlock._2(sampleIdx / 4) >> (2 * (sampleIdx % 4)) & 3
              genotypes(sampleIdx) = converter((sample, twoBits))
              sampleIdx += 1
            }
            r.update(i, new GenericArrayData(genotypes))
          }
        case _ =>
          // BED file only contains genotypes
          (_, _, _) => ()
      }
      fn
    }
    new RowConverter[(Array[UTF8String], Array[Byte])](schema, fns.toArray)
  }

  private def makeGenotypeConverter(gSchema: StructType): RowConverter[(UTF8String, Int)] = {
    val functions = gSchema.map { field =>
      val fn: RowConverter.Updater[(UTF8String, Int)] = field match {
        case f if structFieldsEqualExceptNullability(f, VariantSchemas.sampleIdField) =>
          (sampleAndTwoBits, r, i) => {
            r.update(i, sampleAndTwoBits._1)
          }
        case f if structFieldsEqualExceptNullability(f, VariantSchemas.callsField) =>
          (sampleAndTwoBits, r, i) => r.update(i, twoBitsToCalls(sampleAndTwoBits._2))
        case f =>
          logger.info(
            s"Genotype field $f cannot be derived from PLINK files. It will be null " +
            s"for each sample."
          )
          (_, _, _) => ()
      }
      fn
    }
    new RowConverter[(UTF8String, Int)](gSchema, functions.toArray)
  }

  def convertRow(
      bimRow: InternalRow,
      sampleIds: Array[UTF8String],
      gtBlock: Array[Byte]): InternalRow = {
    converter((sampleIds, gtBlock), bimRow)
  }
} 
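A worked sketch of the 2-bit decoding described in the comments above, for one hypothetical genotype block byte (here the i/4 block index is always this single byte):

val block: Byte = 0xDC.toByte                  // binary 11011100, packing calls for 4 samples
val codes = (0 until 4).map(i => (block >> (2 * i)) & 3)
// codes == Vector(0, 3, 1, 3):
//   sample 0 -> hom alt, sample 1 -> hom ref, sample 2 -> missing, sample 3 -> hom ref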
Example 199
Source File: MeanSubstitute.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.sql.expressions

import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, NumericType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.sql.dsl._
import io.projectglow.sql.util.RewriteAfterResolution

/**
 * Substitutes missing values in a numeric array with the mean of the non-missing values.
 * A value counts as missing if it is NaN, null, or equal to `missingValue` (default -1).
 */
case class MeanSubstitute(array: Expression, missingValue: Expression)
    extends RewriteAfterResolution {

  override def children: Seq[Expression] = Seq(array, missingValue)

  def this(array: Expression) = {
    this(array, Literal(-1))
  }

  private lazy val arrayElementType = array.dataType.asInstanceOf[ArrayType].elementType

  // A value is considered missing if it is NaN, null or equal to the missing value parameter
  def isMissing(arrayElement: Expression): Predicate =
    IsNaN(arrayElement) || IsNull(arrayElement) || arrayElement === missingValue

  def createNamedStruct(sumValue: Expression, countValue: Expression): Expression = {
    val sumName = Literal(UTF8String.fromString("sum"), StringType)
    val countName = Literal(UTF8String.fromString("count"), StringType)
    namedStruct(sumName, sumValue, countName, countValue)
  }

  // Update sum and count with array element if not missing
  def updateSumAndCountConditionally(
      stateStruct: Expression,
      arrayElement: Expression): Expression = {
    If(
      isMissing(arrayElement),
      // If value is missing, do not update sum and count
      stateStruct,
      // If value is not missing, add to sum and increment count
      createNamedStruct(
        stateStruct.getField("sum") + arrayElement,
        stateStruct.getField("count") + 1)
    )
  }

  // Calculate mean for imputation
  def calculateMean(stateStruct: Expression): Expression = {
    If(
      stateStruct.getField("count") > 0,
      // If non-missing values were found, calculate the average
      stateStruct.getField("sum") / stateStruct.getField("count"),
      // If all values were missing, substitute with missing value
      missingValue
    )
  }

  lazy val arrayMean: Expression = {
    // Sum and count of non-missing values
    array.aggregate(
      createNamedStruct(Literal(0d), Literal(0L)),
      updateSumAndCountConditionally,
      calculateMean
    )
  }

  def substituteWithMean(arrayElement: Expression): Expression = {
    If(isMissing(arrayElement), arrayMean, arrayElement)
  }

  override def rewrite: Expression = {
    if (!array.dataType.isInstanceOf[ArrayType] || !arrayElementType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Can only perform mean substitution on numeric array; provided type is ${array.dataType}.")
    }

    if (!missingValue.dataType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Missing value must be of numeric type; provided type is ${missingValue.dataType}.")
    }

    // Replace missing values with the provided strategy
    array.arrayTransform(substituteWithMean(_))
  }
} 
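A hedged usage sketch, assuming this expression is exposed through Glow's SQL function registry under the name mean_substitute (registration itself is not part of this listing) and that a SparkSession `spark` is available:

spark.sql("SELECT mean_substitute(array(1.0, -1.0, 3.0)) AS imputed").show()
// imputed -> [1.0, 2.0, 3.0]: the default sentinel -1 is replaced by mean(1.0, 3.0) = 2.0

spark.sql("SELECT mean_substitute(array(5.0, 0.0, 10.0), 0.0) AS imputed").show()
// with an explicit missing value of 0.0 -> [5.0, 7.5, 10.0]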
Example 200
Source File: MomentAggState.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.sql.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

import io.projectglow.common.GlowLogging


/**
 * Running aggregation state tracking count, mean, min, max and the second central
 * moment (m2) of a stream of doubles. The all-zero defaults stand for an empty state;
 * merge and toInternalRow below guard on count == 0 before reading the other fields.
 */
case class MomentAggState(
    var count: Long = 0,
    var mean: Double = 0,
    var m2: Double = 0,
    var min: Double = 0,
    var max: Double = 0) {

  def toInternalRow(row: InternalRow, offset: Int = 0): InternalRow = {
    row.update(offset, if (count > 0) mean else null)
    row.update(offset + 1, if (count > 0) Math.sqrt(m2 / (count - 1)) else null)
    row.update(offset + 2, if (count > 0) min else null)
    row.update(offset + 3, if (count > 0) max else null)
    row
  }

  def toInternalRow: InternalRow = {
    toInternalRow(new GenericInternalRow(4))
  }
}

object MomentAggState extends GlowLogging {
  val schema = StructType(
    Seq(
      StructField("mean", DoubleType),
      StructField("stdDev", DoubleType),
      StructField("min", DoubleType),
      StructField("max", DoubleType)
    )
  )

  def merge(s1: MomentAggState, s2: MomentAggState): MomentAggState = {
    if (s1.count == 0) {
      return s2
    } else if (s2.count == 0) {
      return s1
    }

    val newState = MomentAggState()
    newState.count = s1.count + s2.count
    val delta = s2.mean - s1.mean
    val deltaN = delta / newState.count
    newState.mean = s1.mean + deltaN * s2.count

    // higher order moments computed according to:
    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
    newState.m2 = s1.m2 + s2.m2 + delta * deltaN * s1.count * s2.count

    newState.min = Math.min(s1.min, s2.min)
    newState.max = Math.max(s1.max, s2.max)
    newState
  }
}
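A small worked example of merge using the pooled-moment formula above; arguments are passed by name against the MomentAggState constructor at the top of this listing:

// State for the values {1.0, 2.0}: mean 1.5, m2 = (1 - 1.5)^2 + (2 - 1.5)^2 = 0.5
val s1 = MomentAggState(count = 2, mean = 1.5, m2 = 0.5, min = 1.0, max = 2.0)
// State for the single value {4.0}
val s2 = MomentAggState(count = 1, mean = 4.0, m2 = 0.0, min = 4.0, max = 4.0)

val merged = MomentAggState.merge(s1, s2)
// merged.count == 3, merged.mean == 7.0 / 3
// merged.m2 == 0.5 + 0.0 + 2.5 * (2.5 / 3) * 2 * 1 ≈ 4.667, i.e. the m2 of {1.0, 2.0, 4.0}
// the stdDev reported by toInternalRow is sqrt(m2 / (count - 1)) ≈ 1.528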