org.apache.spark.sql.catalyst.encoders.RowEncoder Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.encoders.RowEncoder. Each example is taken from the open-source project named in its heading, which also lists the source file and license.
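Before the project examples, here is a minimal, self-contained sketch that is not taken from any project on this page; the session setup, object name and column names are illustrative, and it assumes the pre-Spark-3.5 RowEncoder.apply(schema) API used throughout the examples below. It shows the typical pattern: build an ExpressionEncoder[Row] from a StructType with RowEncoder and pass it explicitly when a DataFrame transformation returns Row.

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StringType

object RowEncoderSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("RowEncoderSketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("id", "name")

    // Schema of the rows produced below: the input columns plus one appended column.
    val outputSchema = df.schema.add("upper_name", StringType)

    // RowEncoder turns a StructType into an ExpressionEncoder[Row]; map/mapPartitions
    // over a DataFrame need such an Encoder[Row] when the result type is Row.
    val encoder = RowEncoder(outputSchema)

    val result = df.map { row =>
      Row.merge(row, Row(row.getAs[String]("name").toUpperCase))
    }(encoder)

    result.show()
    spark.stop()
  }
}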
Example 1
Source File: HTTPTransformer.scala    From mmlspark   with MIT License
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.io.http

import com.microsoft.ml.spark.core.contracts.{HasInputCol, HasOutputCol, Wrappable}
import com.microsoft.ml.spark.io.http.HandlingUtils.HandlerFunc
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration

trait HasHandler extends Params {
  val handler: UDFParam = new UDFParam(
    this, "handler", "Which strategy to use when handling requests")

  
  override def transform(dataset: Dataset[_]): DataFrame = {
    val df = dataset.toDF()
    val enc = RowEncoder(transformSchema(df.schema))
    val colIndex = df.schema.fieldNames.indexOf(getInputCol)
    val fromRow = HTTPRequestData.makeFromRowConverter
    val toRow = HTTPResponseData.makeToRowConverter
    df.mapPartitions { it =>
      if (!it.hasNext) {
        Iterator()
      } else {
        val c = clientHolder.get
        val responsesWithContext = c.sendRequestsWithContext(it.map{row =>
          c.RequestWithContext(Option(row.getStruct(colIndex)).map(fromRow), Some(row))
        })
        responsesWithContext.map { rwc =>
          Row.merge(rwc.context.get.asInstanceOf[Row], Row(rwc.response.flatMap(Option(_)).map(toRow).orNull))
        }
      }
    }(enc)
  }

  def copy(extra: ParamMap): HTTPTransformer = defaultCopy(extra)

  def transformSchema(schema: StructType): StructType = {
    assert(schema(getInputCol).dataType == HTTPSchema.Request)
    schema.add(getOutputCol, HTTPSchema.Response, nullable=true)
  }

} 
Example 2
Source File: DataFrameCreation.scala    From couchbase-spark-connector   with Apache License 2.0
// Putting this in this Spark package to be able to access internalCreateDataFrame
// See my (currently unanswered) SO post for context:
// https://stackoverflow.com/questions/56183811/how-to-create-a-custom-structured-streaming-source-for-apache-spark-2-3-0
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StructType


object DataFrameCreation {
  def createStreamingDataFrame(sqlContext: SQLContext,
                                rdd: RDD[Row],
                                schema: StructType): DataFrame = {
    // internalCreateDataFrame requires an RDD[InternalRow]
    val encoder = RowEncoder.apply(schema)
    val encoded: RDD[InternalRow] = rdd.map(row => {
      encoder.toRow(row)
    })

    sqlContext.internalCreateDataFrame(encoded, schema, isStreaming = true)
  }

  def createStreamingDataFrame(sqlContext: SQLContext,
                               df: DataFrame,
                               schema: StructType): DataFrame = {
    // internalCreateDataFrame requires an RDD[InternalRow]
    val encoder = RowEncoder.apply(schema)
    val encoded: RDD[InternalRow] = df.rdd.map(row => {
      encoder.toRow(row)
    })
    sqlContext.internalCreateDataFrame(encoded, schema, isStreaming = true)
  }
} 
Example 3
Source File: SqlUtils.scala    From spark-acid   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.types.StructType

object SqlUtils {
  def convertToDF(sparkSession: SparkSession, plan : LogicalPlan): DataFrame = {
    Dataset.ofRows(sparkSession, plan)
  }

  def resolveReferences(sparkSession: SparkSession,
                        expr: Expression,
                        planContaining: LogicalPlan, failIfUnresolved: Boolean,
                        exprName: Option[String] = None): Expression = {
    resolveReferences(sparkSession, expr, Seq(planContaining), failIfUnresolved, exprName)
  }

  def resolveReferences(sparkSession: SparkSession,
                        expr: Expression,
                        planContaining: Seq[LogicalPlan],
                        failIfUnresolved: Boolean,
                        exprName: Option[String]): Expression = {
    val newPlan = FakeLogicalPlan(expr, planContaining)
    val resolvedExpr = sparkSession.sessionState.analyzer.execute(newPlan) match {
      case FakeLogicalPlan(resolvedExpr: Expression, _) =>
        // Return even if it did not successfully resolve
        resolvedExpr
      case _ =>
        expr
      // This is unexpected
    }
    if (failIfUnresolved) {
      resolvedExpr.flatMap(_.references).filter(!_.resolved).foreach {
        attr => {
          val failedMsg = exprName match {
            case Some(name) => s"${attr.sql} resolution in $name given these columns: "+
              planContaining.flatMap(_.output).map(_.name).mkString(",")
            case _ => s"${attr.sql} resolution failed given these columns: "+
              planContaining.flatMap(_.output).map(_.name).mkString(",")
          }
          attr.failAnalysis(failedMsg)
        }
      }
    }
    resolvedExpr
  }

  def hasSparkStopped(sparkSession: SparkSession): Boolean = {
    sparkSession.sparkContext.stopped.get()
  }

  
  def createDataFrameUsingAttributes(sparkSession: SparkSession,
                                     rdd: RDD[Row],
                                     schema: StructType,
                                     attributes: Seq[Attribute]): DataFrame = {
    val encoder = RowEncoder(schema)
    val catalystRows = rdd.map(encoder.toRow)
    val logicalPlan = LogicalRDD(
      attributes,
      catalystRows,
      isStreaming = false)(sparkSession)
    Dataset.ofRows(sparkSession, logicalPlan)
  }

  def analysisException(cause: String): Throwable = {
    new AnalysisException(cause)
  }
}

case class FakeLogicalPlan(expr: Expression, children: Seq[LogicalPlan])
  extends LogicalPlan {
  override def output: Seq[Attribute] = children.foldLeft(Seq[Attribute]())((out, child) => out ++ child.output)
} 
Example 4
Source File: SparkStreamletContextImpl.scala    From cloudflow   with Apache License 2.0
package cloudflow.spark.kafka

import java.io.File

import com.typesafe.config.Config
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.{ ExpressionEncoder, RowEncoder }
import org.apache.spark.sql.streaming.{ OutputMode, StreamingQuery }
import cloudflow.spark.SparkStreamletContext
import cloudflow.spark.avro.{ SparkAvroDecoder, SparkAvroEncoder }
import cloudflow.spark.sql.SQLImplicits._
import cloudflow.streamlets._

import scala.reflect.runtime.universe._

class SparkStreamletContextImpl(
    private[cloudflow] override val streamletDefinition: StreamletDefinition,
    session: SparkSession,
    override val config: Config
) extends SparkStreamletContext(streamletDefinition, session) {

  val storageDir           = config.getString("storage.mountPath")
  val maxOffsetsPerTrigger = config.getLong("cloudflow.spark.read.options.max-offsets-per-trigger")
  def readStream[In](inPort: CodecInlet[In])(implicit encoder: Encoder[In], typeTag: TypeTag[In]): Dataset[In] = {

    implicit val inRowEncoder: ExpressionEncoder[Row] = RowEncoder(encoder.schema)
    val schema                                        = inPort.schemaAsString
    val topic                                         = findTopicForPort(inPort)
    val srcTopic                                      = topic.name
    val brokers                                       = topic.bootstrapServers.getOrElse(internalKafkaBootstrapServers)
    val src: DataFrame = session.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", brokers)
      .options(kafkaConsumerMap(topic))
      .option("maxOffsetsPerTrigger", maxOffsetsPerTrigger)
      .option("subscribe", srcTopic)
      // Allow restart of stateful streamlets that may have been offline for longer than the kafka retention period.
      // This setting may result in data loss in some cases but allows for continuity of the runtime
      .option("failOnDataLoss", false)
      .option("startingOffsets", "earliest")
      .load()

    val rawDataset = src.select($"value").as[Array[Byte]]

    val dataframe: Dataset[Row] = rawDataset.mapPartitions { iter ⇒
      val avroDecoder = new SparkAvroDecoder[In](schema)
      iter.map(avroDecoder.decode)
    }(inRowEncoder)

    dataframe.as[In]
  }

  def kafkaConsumerMap(topic: Topic) = topic.kafkaConsumerProperties.map {
    case (key, value) => s"kafka.$key" -> value
  }
  def kafkaProducerMap(topic: Topic) = topic.kafkaProducerProperties.map {
    case (key, value) => s"kafka.$key" -> value
  }

  def writeStream[Out](stream: Dataset[Out], outPort: CodecOutlet[Out], outputMode: OutputMode)(implicit encoder: Encoder[Out],
                                                                                                typeTag: TypeTag[Out]): StreamingQuery = {

    val avroEncoder   = new SparkAvroEncoder[Out](outPort.schemaAsString)
    val encodedStream = avroEncoder.encodeWithKey(stream, outPort.partitioner)

    val topic     = findTopicForPort(outPort)
    val destTopic = topic.name
    val brokers   = topic.bootstrapServers.getOrElse(internalKafkaBootstrapServers)

    // metadata checkpoint directory on mount
    val checkpointLocation = checkpointDir(outPort.name)
    val queryName          = s"$streamletRef.$outPort"

    encodedStream.writeStream
      .outputMode(outputMode)
      .format("kafka")
      .queryName(queryName)
      .option("kafka.bootstrap.servers", brokers)
      .options(kafkaProducerMap(topic))
      .option("topic", destTopic)
      .option("checkpointLocation", checkpointLocation)
      .start()
  }

  def checkpointDir(dirName: String): String = {
    val baseCheckpointDir = new File(storageDir, streamletRef)
    val dir               = new File(baseCheckpointDir, dirName)
    if (!dir.exists()) {
      val created = dir.mkdirs()
      require(created, s"Could not create checkpoint directory: $dir")
    }
    dir.getAbsolutePath
  }
} 
Example 5
Source File: SparkAvroDecoder.scala    From cloudflow   with Apache License 2.0
package cloudflow.spark.avro

import org.apache.log4j.Logger

import java.io.ByteArrayOutputStream

import scala.reflect.runtime.universe._

import org.apache.avro.generic.{ GenericDatumReader, GenericDatumWriter, GenericRecord }
import org.apache.avro.io.{ DecoderFactory, EncoderFactory }
import org.apache.spark.sql.{ Dataset, Encoder, Row }
import org.apache.spark.sql.catalyst.encoders.{ encoderFor, ExpressionEncoder, RowEncoder }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.StructType
import org.apache.avro.Schema

import cloudflow.spark.sql.SQLImplicits._

case class EncodedKV(key: String, value: Array[Byte])

case class SparkAvroDecoder[T: Encoder: TypeTag](avroSchema: String) {

  val encoder: Encoder[T]                           = implicitly[Encoder[T]]
  val sqlSchema: StructType                         = encoder.schema
  val encoderForDataColumns: ExpressionEncoder[Row] = RowEncoder(sqlSchema)
  @transient lazy val _avroSchema                   = new Schema.Parser().parse(avroSchema)
  @transient lazy val rowConverter                  = SchemaConverters.createConverterToSQL(_avroSchema, sqlSchema)
  @transient lazy val datumReader                   = new GenericDatumReader[GenericRecord](_avroSchema)
  @transient lazy val decoder                       = DecoderFactory.get
  def decode(bytes: Array[Byte]): Row = {
    val binaryDecoder = decoder.binaryDecoder(bytes, null)
    val record        = datumReader.read(null, binaryDecoder)
    rowConverter(record).asInstanceOf[GenericRow]
  }

}


case class SparkAvroEncoder[T: Encoder: TypeTag](avroSchema: String) {

  @transient lazy val log = Logger.getLogger(getClass.getName)

  val BufferSize = 5 * 1024 // 5 Kb

  val encoder                     = implicitly[Encoder[T]]
  val sqlSchema                   = encoder.schema
  @transient lazy val _avroSchema = new Schema.Parser().parse(avroSchema)

  val recordName                = "topLevelRecord" // ???
  val recordNamespace           = "recordNamespace" // ???
  @transient lazy val converter = AvroConverter.createConverterToAvro(sqlSchema, recordName, recordNamespace)

  // Risk: This process is memory intensive. Might require thread-level buffers to optimize memory usage
  def rowToBytes(row: Row): Array[Byte] = {
    val genRecord = converter(row).asInstanceOf[GenericRecord]
    if (log.isDebugEnabled) log.debug(s"genRecord = $genRecord")
    val datumWriter   = new GenericDatumWriter[GenericRecord](_avroSchema)
    val avroEncoder   = EncoderFactory.get
    val byteArrOS     = new ByteArrayOutputStream(BufferSize)
    val binaryEncoder = avroEncoder.binaryEncoder(byteArrOS, null)
    datumWriter.write(genRecord, binaryEncoder)
    binaryEncoder.flush()
    byteArrOS.toByteArray
  }

  def encode(dataset: Dataset[T]): Dataset[Array[Byte]] =
    dataset.toDF().mapPartitions(rows ⇒ rows.map(rowToBytes)).as[Array[Byte]]

  // Note to self: I'm not sure how heavy this chain of transformations is
  def encodeWithKey(dataset: Dataset[T], keyFun: T ⇒ String): Dataset[EncodedKV] = {
    val encoder             = encoderFor[T]
    implicit val rowEncoder = RowEncoder(encoder.schema).resolveAndBind()
    dataset.map { value ⇒
      val key         = keyFun(value)
      val internalRow = encoder.toRow(value)
      val row         = rowEncoder.fromRow(internalRow)
      val bytes       = rowToBytes(row)
      EncodedKV(key, bytes)
    }
  }

} 
Example 6
Source File: GroupedIteratorSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema)

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 7
Source File: GroupedIteratorSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 8
Source File: DataSourceV2ScanExec.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.v2

import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
import org.apache.spark.sql.types.StructType


case class DataSourceV2ScanExec(
    output: Seq[AttributeReference],
    @transient reader: DataSourceReader)
  extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {

  override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]

  override def outputPartitioning: physical.Partitioning = reader match {
    case s: SupportsReportPartitioning =>
      new DataSourcePartitioning(
        s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))

    case _ => super.outputPartitioning
  }

  private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match {
    case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories()
    case _ =>
      reader.createDataReaderFactories().asScala.map {
        new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
      }.asJava
  }

  private lazy val inputRDD: RDD[InternalRow] = reader match {
    case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
      assert(!reader.isInstanceOf[ContinuousReader],
        "continuous stream reader does not support columnar read yet.")
      new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories())
        .asInstanceOf[RDD[InternalRow]]

    case _: ContinuousReader =>
      EpochCoordinatorRef.get(
          sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
          sparkContext.env)
        .askSync[Unit](SetReaderPartitions(readerFactories.size()))
      new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
        .asInstanceOf[RDD[InternalRow]]

    case _ =>
      new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]]
  }

  override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD)

  override val supportsBatch: Boolean = reader match {
    case r: SupportsScanColumnarBatch if r.enableBatchRead() => true
    case _ => false
  }

  override protected def needsUnsafeRowConversion: Boolean = false

  override protected def doExecute(): RDD[InternalRow] = {
    if (supportsBatch) {
      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
    } else {
      val numOutputRows = longMetric("numOutputRows")
      inputRDD.map { r =>
        numOutputRows += 1
        r
      }
    }
  }
}

class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType)
  extends DataReaderFactory[UnsafeRow] {

  override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations

  override def createDataReader: DataReader[UnsafeRow] = {
    new RowToUnsafeDataReader(
      rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind())
  }
}

class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row])
  extends DataReader[UnsafeRow] {

  override def next: Boolean = rowReader.next

  override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow]

  override def close(): Unit = rowReader.close()
} 
Example 9
Source File: GroupedIteratorSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 10
Source File: PowerBiSuite.scala    From mmlspark   with MIT License
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.io.split1

import java.io.File

import com.microsoft.ml.spark.Secrets
import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.io.powerbi.PowerBIWriter
import org.apache.spark.SparkException
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.{current_timestamp, lit}

import scala.collection.JavaConverters._

class PowerBiSuite extends TestBase with FileReaderUtils {

  lazy val url: String = sys.env.getOrElse("MML_POWERBI_URL", Secrets.PowerbiURL)
  lazy val df: DataFrame = session
    .createDataFrame(Seq(
      (Some(0), "a"),
      (Some(1), "b"),
      (Some(2), "c"),
      (Some(3), ""),
      (None, "bad_row")))
    .toDF("bar", "foo")
    .withColumn("baz", current_timestamp())
  lazy val bigdf: DataFrame = (1 to 5).foldRight(df) { case (_, ldf) => ldf.union(df) }.repartition(2)
  lazy val delayDF: DataFrame = {
    val rows = Array.fill(100){df.collect()}.flatten.toList.asJava
    val df2 = session
      .createDataFrame(rows, df.schema)
      .coalesce(1).cache()
    df2.count()
    df2.map({x => Thread.sleep(10); x})(RowEncoder(df2.schema))
  }

  test("write to powerBi", TestBase.BuildServer) {
    PowerBIWriter.write(df, url)
  }

  test("write to powerBi with delays"){
    PowerBIWriter.write(delayDF, url)
  }

  test("using dynamic minibatching"){
    PowerBIWriter.write(delayDF, url, Map("minibatcher"->"dynamic", "maxBatchSize"->"50"))
  }

  test("using timed minibatching"){
    PowerBIWriter.write(delayDF, url, Map("minibatcher"->"timed"))
  }

  test("using consolidated timed minibatching"){
    PowerBIWriter.write(delayDF, url, Map(
      "minibatcher"->"timed",
      "consolidate"->"true"))
  }

  test("using buffered batching"){
    PowerBIWriter.write(delayDF, url, Map("buffered"->"true"))
  }

  ignore("throw useful error message when given an improper dataset") {
    //TODO figure out why this does not throw errors on the build machine
    assertThrows[SparkException] {
      PowerBIWriter.write(df.withColumn("bad", lit("foo")), url)
    }
  }

  test("stream to powerBi", TestBase.BuildServer) {
    bigdf.write.parquet(tmpDir + File.separator + "powerBI.parquet")
    val sdf = session.readStream.schema(df.schema).parquet(tmpDir + File.separator + "powerBI.parquet")
    val q1 = PowerBIWriter.stream(sdf, url).start()
    q1.processAllAvailable()
  }

} 
Example 11
Source File: StratifiedRepartitionSuite.scala    From mmlspark   with MIT License
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.spark.TaskContext
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class StratifiedRepartitionSuite extends TestBase with TransformerFuzzing[StratifiedRepartition] {

  import session.implicits._

  val values = "values"
  val colors = "colors"
  val const = "const"

  lazy val input = Seq(
    (0, "Blue", 2),
    (0, "Red", 2),
    (0, "Green", 2),
    (1, "Purple", 2),
    (1, "Orange", 2),
    (1, "Indigo", 2),
    (2, "Violet", 2),
    (2, "Black", 2),
    (2, "White", 2),
    (3, "Gray", 2),
    (3, "Yellow", 2),
    (3, "Cerulean", 2)
  ).toDF(values, colors, const)

  test("Assert doing a stratified repartition will ensure all keys exist across all partitions") {
    val inputSchema = new StructType()
      .add(values, IntegerType).add(colors, StringType).add(const, IntegerType)
    val inputEnc = RowEncoder(inputSchema)
    val valuesFieldIndex = inputSchema.fieldIndex(values)
    val numPartitions = 3
    val trainData = input.repartition(numPartitions).select(values, colors, const)
      .mapPartitions(iter => {
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        // Remove all instances of 0 class on partition 1
        if (partId == 1) {
          iter.flatMap(row => {
            if (row.getInt(valuesFieldIndex) <= 0)
              None
            else Some(row)
          })
        } else {
          // Add back at least 3 instances on other partitions
          val oneOfEachExample = List(Row(0, "Blue", 2), Row(1, "Purple", 2), Row(2, "Black", 2), Row(3, "Gray", 2))
          (iter.toList.union(oneOfEachExample).union(oneOfEachExample).union(oneOfEachExample)).toIterator
        }
      })(inputEnc).cache()
    // Some debug to understand what data is on which partition
    trainData.foreachPartition { rows =>
      rows.foreach { row =>
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        println(s"Row: $row partition id: $partId")
      }
    }
    val stratifiedInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Equal).transform(trainData)
    // Assert stratified data contains all keys across all partitions, with extra count
    // for it to be evaluated
    stratifiedInputData
      .mapPartitions(iter => {
        val actualLabels = iter.map(row => row.getInt(valuesFieldIndex))
          .toArray.distinct.sorted.toList
        val expectedLabels = (0 to 3).toList
        if (actualLabels != expectedLabels)
          throw new Exception(s"Missing labels, actual: $actualLabels, expected: $expectedLabels")
        iter
      })(inputEnc).count()
    val stratifiedMixedInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Mixed).transform(trainData)
    assert(stratifiedMixedInputData.count() >= trainData.count())
    val stratifiedOriginalInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Original).transform(trainData)
    assert(stratifiedOriginalInputData.count() == trainData.count())
  }

  def testObjects(): Seq[TestObject[StratifiedRepartition]] = List(new TestObject(
    new StratifiedRepartition().setLabelCol(values).setMode(SPConstants.Equal), input))

  def reader: MLReadable[_] = StratifiedRepartition
} 
Example 12
Source File: PartitionConsolidatorSuite.scala    From mmlspark   with MIT License
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.flaky

import com.microsoft.ml.spark.core.test.base.TimeLimitedFlaky
import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import com.microsoft.ml.spark.io.http.PartitionConsolidator
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.scalatest.Assertion

class PartitionConsolidatorSuite extends TransformerFuzzing[PartitionConsolidator] with TimeLimitedFlaky {

  import session.implicits._

  override val numCores: Option[Int] = Some(2)

  lazy val df: DataFrame = (1 to 1000).toDF("values")

  override val sortInDataframeEquality: Boolean = true

  override def testObjects(): Seq[TestObject[PartitionConsolidator]] = Seq(
    new TestObject(new PartitionConsolidator(), df))

  override def reader: MLReadable[_] = PartitionConsolidator

  def getPartitionDist(df: DataFrame): List[Int] = {
    df.rdd.mapPartitions(it => Iterator(it.length)).collect().toList
  }

  //TODO figure out what is causing the issue on the build server
  override def testSerialization(): Unit = {}

  override def testExperiments(): Unit = {}

  def basicTest(df: DataFrame): Assertion = {
    val pd1 = getPartitionDist(df)
    val newDF = new PartitionConsolidator().transform(df)
    val pd2 = getPartitionDist(newDF)
    assert(pd1.sum === pd2.sum)
    assert(pd2.max >= pd1.max)
    assert(pd1.length === pd2.length)
  }

  test("basic functionality") {
    basicTest(df)
  }

  test("works with more partitions than cores") {
    basicTest(df.repartition(12))
  }

  test("overheads") {
    val baseDF = (1 to 1000).toDF("values").cache()
    println(baseDF.count())

    def getDF: Dataset[Row] = baseDF.map { x => Thread.sleep(10); x }(
      RowEncoder(new StructType().add("values", DoubleType)))

    val t1 = getTime(3)(
      getDF.foreach(_ => ()))._2
    val t2 = getTime(3)(
      new PartitionConsolidator().transform(getDF).foreach(_ => ()))._2

    println(t2.toDouble / t1.toDouble)
    assert(t2.toDouble / t1.toDouble < 3.0)
  }

  test("works with more partitions than cores2") {
    basicTest(df.repartition(100))
  }

  test("work with 1 partition") {
    basicTest(df.repartition(1))
  }

} 
Example 13
Source File: BinaryFileReader.scala    From mmlspark   with MIT License
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark

import com.microsoft.ml.spark.core.env.StreamUtilities
import com.microsoft.ml.spark.core.schema.BinaryFileSchema
import com.microsoft.ml.spark.core.utils.AsyncUtils
import org.apache.commons.io.IOUtils
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.spark.binary.BinaryFileFormat
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.binary.ConfUtils
import org.apache.spark.sql.types.BinaryType

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

object BinaryFileReader {

  private def recursePath(fileSystem: FileSystem,
                          path: Path,
                          pathFilter: FileStatus => Boolean,
                          visitedSymlinks: Set[Path]): Array[Path] ={
    val filteredPaths = fileSystem.listStatus(path).filter(pathFilter)
    val filteredDirs = filteredPaths.filter(fs => fs.isDirectory & !visitedSymlinks(fs.getPath))
    val symlinksFound = visitedSymlinks ++ filteredDirs.filter(_.isSymlink).map(_.getPath)
    filteredPaths.map(_.getPath) ++ filteredDirs.map(_.getPath)
      .flatMap(p => recursePath(fileSystem, p, pathFilter, symlinksFound))
  }

  def recursePath(fileSystem: FileSystem, path: Path, pathFilter: FileStatus => Boolean): Array[Path] ={
    recursePath(fileSystem, path, pathFilter, Set())
  }

  
  def readFromPaths(df: DataFrame,
                    pathCol: String,
                    bytesCol: String,
                    concurrency: Int,
                    timeout: Int
                   ): DataFrame = {
    val outputSchema = df.schema.add(bytesCol, BinaryType, nullable = true)
    val encoder = RowEncoder(outputSchema)
    val hconf = ConfUtils.getHConf(df)

    df.mapPartitions { rows =>
      val futures = rows.map {row: Row =>
        Future {
            val path = new Path(row.getAs[String](pathCol))
            val fs = path.getFileSystem(hconf.value)
            val bytes = StreamUtilities.using(fs.open(path)) {is => IOUtils.toByteArray(is)}.get
            val ret = Row.merge(Seq(row, Row(bytes)): _*)
            ret
          }(ExecutionContext.global)
      }
      AsyncUtils.bufferedAwait(
        futures,concurrency, Duration.fromNanos(timeout*(20^6).toLong))(ExecutionContext.global)
    }(encoder)
  }

} 
Example 14
Source File: GroupedIteratorSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 15
Source File: SparkBindings.scala    From mmlspark   with MIT License
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.core.schema

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.types.StructType

import scala.reflect.runtime.universe.TypeTag

abstract class SparkBindings[T: TypeTag] extends Serializable {

  lazy val schema: StructType = enc.schema
  private lazy val enc: ExpressionEncoder[T] = ExpressionEncoder[T]().resolveAndBind()
  private lazy val rowEnc: ExpressionEncoder[Row] = RowEncoder(enc.schema).resolveAndBind()

  // WARNING: each time you use this function on a dataframe, you should make a new converter.
  // Spark does some magic that makes this leak memory if re-used on a
  // different symbolic node of the parallel computation. That being said,
  // you should make a single converter before using it in a udf so
  // that the slow resolving and binding is not in the hotpath
  def makeFromRowConverter: Row => T = {
    val enc1 = enc.resolveAndBind()
    val rowEnc1 = rowEnc.resolveAndBind();
    { r: Row => enc1.fromRow(rowEnc1.toRow(r)) }
  }

  def makeFromInternalRowConverter: InternalRow => T = {
    val enc1 = enc.resolveAndBind();
    { r: InternalRow => enc1.fromRow(r) }
  }

  def makeToRowConverter: T => Row = {
    val enc1 = enc.resolveAndBind()
    val rowEnc1 = rowEnc.resolveAndBind();
    { v: T => rowEnc1.fromRow(enc1.toRow(v)) }
  }

  def makeToInternalRowConverter: T => InternalRow = {
    val enc1 = enc.resolveAndBind();
    { v: T => enc1.toRow(v) }
  }

} 
Example 16
Source File: MultiNGram.scala    From mmlspark   with MIT License
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.featurize.text

import com.microsoft.ml.spark.core.contracts.{HasInputCol, HasOutputCol, Wrappable}
import com.microsoft.ml.spark.core.schema.DatasetExtensions
import org.apache.spark.ml._
import org.apache.spark.ml.feature._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

import scala.collection.mutable

object MultiNGram extends DefaultParamsReadable[MultiNGram]


class MultiNGram(override val uid: String)
  extends Transformer with HasInputCol with HasOutputCol with Wrappable with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("MultiNGram"))

  setDefault(outputCol, uid + "_output")

  val lengths =
    new ArrayParam(this, "lengths",
      "the collection of lengths to use for ngram extraction")

  def getLengths: Array[Int] = $(lengths)
    .toArray.map {
    case i: scala.math.BigInt => i.toInt
    case i: java.lang.Integer => i.toInt
  }

  def setLengths(v: Array[Int]): this.type = set(lengths, v)

  override def transform(dataset: Dataset[_]): DataFrame = {
    val df = dataset.toDF()
    val intermediateOutputCols = getLengths.map(n =>
      DatasetExtensions.findUnusedColumnName(s"ngram_$n")(dataset.columns.toSet)
    )
    val models = getLengths.zip(intermediateOutputCols).map { case (n, out) =>
      new NGram().setN(n).setInputCol(getInputCol).setOutputCol(out)
    }
    val intermediateDF = NamespaceInjections.pipelineModel(models).transform(df)
    intermediateDF.map { row =>
      val mergedNGrams = intermediateOutputCols
        .map(col => row.getAs[Seq[String]](col))
        .reduce(_ ++ _)
      Row.merge(row, Row(mergedNGrams))
    }(RowEncoder(intermediateDF.schema.add(getOutputCol, ArrayType(StringType))))
      .drop(intermediateOutputCols: _*)
  }

  override def copy(extra: ParamMap): MultiNGram =
    defaultCopy(extra)

  def transformSchema(schema: StructType): StructType = {
    assert(schema(getInputCol).dataType == ArrayType(StringType))
    schema.add(getOutputCol, ArrayType(StringType))
  }

} 
Example 17
Source File: BigQueryAdapter.scala    From spark-bigquery   with Apache License 2.0
package com.samelamin.spark.bigquery.converters

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.current_timestamp
import org.apache.spark.sql.types._


object BigQueryAdapter {

  private def adaptName(name: String, siblings: Array[String]): String = {
    var newName = name.replaceAll("\\W", "_")
    if (!newName.equals(name)) {
      // Avoid duplicates:
      var counter = 1;
      while (!siblings.find(_.equals(newName)).isEmpty) {
        newName = newName + "_" + counter
        counter = counter + 1
      }
    }
    newName
  }

  private def adaptField(structField: StructField, parentType: StructType): StructField = {
    new StructField(adaptName(structField.name, parentType.fieldNames), adaptType(structField.dataType), structField.nullable)
  }

  private def adaptType(dataType: DataType): DataType = {
    dataType match {
      case structType: StructType =>
        new StructType(structType.fields.map(adaptField(_, structType)))
      case arrayType: ArrayType =>
        new ArrayType(adaptType(arrayType.elementType), arrayType.containsNull)
      case mapType: MapType =>
        new MapType(adaptType(mapType.keyType), adaptType(mapType.valueType), mapType.valueContainsNull)
      case other => other
    }
  }

  def apply(df: DataFrame): DataFrame = {
    val sqlContext = df.sparkSession.sqlContext
    val sparkContext = df.sparkSession.sparkContext
    val timestampColumn = sparkContext
      .hadoopConfiguration.get("timestamp_column","bq_load_timestamp")
    val newSchema = adaptType(df.schema).asInstanceOf[StructType]
    val encoder = RowEncoder.apply(newSchema).resolveAndBind()
    val encodedDF = df
      .queryExecution
      .toRdd.map(x => encoder.fromRow(x))
    sqlContext.createDataFrame(encodedDF, newSchema).withColumn(timestampColumn, current_timestamp())
  }
} 
Example 18
Source File: RecoverPartitionsCustomTest.scala    From m3d-engine   with Apache License 2.0
package com.adidas.analytics.unit

import com.adidas.analytics.util.RecoverPartitionsCustom
import com.adidas.utils.SparkSessionWrapper
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, PrivateMethodTester}

import scala.collection.JavaConverters._

class RecoverPartitionsCustomTest extends FunSuite
  with SparkSessionWrapper
  with PrivateMethodTester
  with Matchers
  with BeforeAndAfterAll{

  test("test conversion of String Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue("theValue")

    result should be("'theValue'")
  }

  test("test conversion of Short Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Short.valueOf("2"))

    result should be("2")
  }

  test("test conversion of Integer Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Integer.valueOf("4"))

    result should be("4")
  }

  test("test conversion of null Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(null)
    }
  }

  test("test conversion of not supported Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(false)
    }
  }

  test("test HiveQL statements Generation") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(
      tableName="test",
      targetPartitions = Seq("country","district")
    )

    val rowsInput = Seq(
      Row(1, "portugal", "porto"),
      Row(2, "germany", "herzogenaurach"),
      Row(3, "portugal", "coimbra")
    )

    val inputSchema = StructType(
      List(
        StructField("number", IntegerType, nullable = true),
        StructField("country", StringType, nullable = true),
        StructField("district", StringType, nullable = true)
      )
    )

    val expectedStatements: Seq[String] = Seq(
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='germany',district='herzogenaurach')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='coimbra')"
    )

    val testDataset: Dataset[Row] = spark.createDataset(rowsInput)(RowEncoder(inputSchema))

    val createParameterValue = PrivateMethod[Dataset[String]]('generateAddPartitionStatements)

    val producedStatements: Seq[String] = (customSparkRecoverPartitions invokePrivate createParameterValue(testDataset))
      .collectAsList()
      .asScala

    expectedStatements.sorted.toSet should equal(producedStatements.sorted.toSet)
  }

  override def afterAll(): Unit = {
    spark.stop()
  }

} 
Example 19
Source File: SparkRecoverPartitionsCustomTest.scala    From m3d-engine   with Apache License 2.0
package com.adidas.analytics.unit

import com.adidas.analytics.util.SparkRecoverPartitionsCustom
import com.adidas.utils.SparkSessionWrapper
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, PrivateMethodTester}

import scala.collection.JavaConverters._

class SparkRecoverPartitionsCustomTest extends FunSuite
  with SparkSessionWrapper
  with PrivateMethodTester
  with Matchers
  with BeforeAndAfterAll{

  test("test conversion of String Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue("theValue")

    result should be("'theValue'")
  }

  test("test conversion of Short Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Short.valueOf("2"))

    result should be("2")
  }

  test("test conversion of Integer Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Integer.valueOf("4"))

    result should be("4")
  }

  test("test conversion of null Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(null)
    }
  }

  test("test conversion of not supported Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(false)
    }
  }

  test("test HiveQL statements Generation") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(
      tableName="test",
      targetPartitions = Seq("country","district")
    )

    val rowsInput = Seq(
      Row(1, "portugal", "porto"),
      Row(2, "germany", "herzogenaurach"),
      Row(3, "portugal", "coimbra")
    )

    val inputSchema = StructType(
      List(
        StructField("number", IntegerType, nullable = true),
        StructField("country", StringType, nullable = true),
        StructField("district", StringType, nullable = true)
      )
    )

    val expectedStatements: Seq[String] = Seq(
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='germany',district='herzogenaurach')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='coimbra')"
    )

    val testDataset: Dataset[Row] = spark.createDataset(rowsInput)(RowEncoder(inputSchema))

    val createParameterValue = PrivateMethod[Dataset[String]]('generateAddPartitionStatements)

    val producedStatements: Seq[String] = (customSparkRecoverPartitions invokePrivate createParameterValue(testDataset))
      .collectAsList()
      .asScala

    expectedStatements.sorted.toSet should equal(producedStatements.sorted.toSet)
  }

  override def afterAll(): Unit = {
    spark.stop()
  }

} 
Example 20
Source File: RowStreamParserImp.scala    From carbondata   with Apache License 2.0
package org.apache.carbondata.streaming.parser

import java.text.SimpleDateFormat
import java.util

import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.processing.loading.ComplexDelimitersEnum
import org.apache.carbondata.processing.loading.constants.DataLoadProcessorConstants


class RowStreamParserImp extends CarbonStreamParser {

  var configuration: Configuration = null
  var isVarcharTypeMapping: Array[Boolean] = null
  var structType: StructType = null
  var encoder: ExpressionEncoder[Row] = null

  var timeStampFormat: SimpleDateFormat = null
  var dateFormat: SimpleDateFormat = null
  var complexDelimiters: util.ArrayList[String] = new util.ArrayList[String]()
  var serializationNullFormat: String = null

  override def initialize(configuration: Configuration,
      structType: StructType, isVarcharTypeMapping: Array[Boolean]): Unit = {
    this.configuration = configuration
    this.structType = structType
    this.encoder = RowEncoder.apply(this.structType).resolveAndBind()
    this.isVarcharTypeMapping = isVarcharTypeMapping

    this.timeStampFormat = new SimpleDateFormat(
      this.configuration.get(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT))
    this.dateFormat = new SimpleDateFormat(
      this.configuration.get(CarbonCommonConstants.CARBON_DATE_FORMAT))
    this.complexDelimiters.add(this.configuration.get("carbon_complex_delimiter_level_1"))
    this.complexDelimiters.add(this.configuration.get("carbon_complex_delimiter_level_2"))
    this.complexDelimiters.add(this.configuration.get("carbon_complex_delimiter_level_3"))
    this.complexDelimiters.add(ComplexDelimitersEnum.COMPLEX_DELIMITERS_LEVEL_4.value())
    this.serializationNullFormat =
      this.configuration.get(DataLoadProcessorConstants.SERIALIZATION_NULL_FORMAT)
  }

  override def parserRow(value: InternalRow): Array[Object] = {
    this.encoder.fromRow(value).toSeq.zipWithIndex.map { case (x, i) =>
      FieldConverter.objectToString(
        x, serializationNullFormat, complexDelimiters,
        timeStampFormat, dateFormat,
        isVarcharType = i < this.isVarcharTypeMapping.length && this.isVarcharTypeMapping(i),
        binaryCodec = null)
    }.toArray
  }

  override def close(): Unit = {
  }

} 
Example 21
Source File: GroupedIteratorSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 22
Source File: GroupedIteratorSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 23
Source File: ArrayDataIndexedSeqSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection}
import org.apache.spark.sql.types._

class ArrayDataIndexedSeqSuite extends SparkFunSuite {
  private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = {
    assert(arrayData.numElements == array.length)
    array.zipWithIndex.map { case (e, i) =>
      if (e != null) {
        elementDt match {
          // For NaN, etc.
          case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e))
          case _ => assert(arrayData.get(i, elementDt) === e)
        }
      } else {
        assert(arrayData.isNullAt(i))
      }
    }

    val seq = arrayData.toSeq[Any](elementDt)
    array.zipWithIndex.map { case (e, i) =>
      if (e != null) {
        elementDt match {
          // For Nan, etc.
          case FloatType | DoubleType => assert(seq(i).equals(e))
          case _ => assert(seq(i) === e)
        }
      } else {
        assert(seq(i) == null)
      }
    }

    assert(intercept[IndexOutOfBoundsException] {
      seq(-1)
    }.getMessage().contains("must be between 0 and the length of the ArrayData."))

    assert(intercept[IndexOutOfBoundsException] {
      seq(seq.length)
    }.getMessage().contains("must be between 0 and the length of the ArrayData."))
  }

  private def testArrayData(): Unit = {
    val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
      DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType,
      CalendarIntervalType, new ExamplePointUDT())
    val arrayTypes = elementTypes.flatMap { elementType =>
      Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true))
    }
    val random = new Random(100)
    arrayTypes.foreach { dt =>
      val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil)
      val row = RandomDataGenerator.randomRow(random, schema)
      val rowConverter = RowEncoder(schema)
      val internalRow = rowConverter.toRow(row)

      val unsafeRowConverter = UnsafeProjection.create(schema)
      val safeRowConverter = FromUnsafeProjection(schema)

      val unsafeRow = unsafeRowConverter(internalRow)
      val safeRow = safeRowConverter(unsafeRow)

      val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData]
      val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]

      val elementType = dt.elementType
      test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) {
        compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType))
      }

      test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) {
        compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType))
      }
    }
  }

  testArrayData()
} 
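The suite above first serializes a random external Row into an InternalRow with RowEncoder and then projects it into the UnsafeRow binary layout. A stripped-down sketch of that conversion path, using the same catalyst classes; the schema and value are illustrative only:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object UnsafeRowFromRowEncoder {
  def main(args: Array[String]): Unit = {
    val schema = StructType(StructField("col_1", IntegerType, nullable = false) :: Nil)

    // RowEncoder serializes the external Row into an InternalRow ...
    val internalRow = RowEncoder(schema).toRow(Row(42))

    // ... and UnsafeProjection rewrites it into the compact UnsafeRow format.
    val unsafeRow = UnsafeProjection.create(schema)(internalRow)

    assert(unsafeRow.getInt(0) == 42)
  }
}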
Example 24
Source File: StreamingIncrementCommand.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.xsql.execution.command

import java.util.Locale

import org.apache.spark.SparkException
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.StreamingRelationV2
import org.apache.spark.sql.sources.v2.StreamWriteSupport
import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import org.apache.spark.sql.xsql.DataSourceManager._
import org.apache.spark.sql.xsql.StreamingSinkType


case class StreamingIncrementCommand(plan: LogicalPlan) extends RunnableCommand {

  private var outputMode: OutputMode = OutputMode.Append
  // dummy
  override def output: Seq[AttributeReference] = Seq.empty
  // dummy
  override def producedAttributes: AttributeSet = plan.producedAttributes

  override def run(sparkSession: SparkSession): Seq[Row] = {
    import StreamingSinkType._
    val qe = new QueryExecution(sparkSession, new ConstructedStreaming(plan))
    val df = new Dataset(sparkSession, qe, RowEncoder(qe.analyzed.schema))
    plan.collectLeaves.head match {
      case StreamingRelationV2(_, _, extraOptions, _, _) =>
        val source = extraOptions.getOrElse(STREAMING_SINK_TYPE, DEFAULT_STREAMING_SINK)
        val sinkOptions = extraOptions.filter(_._1.startsWith(STREAMING_SINK_PREFIX)).map { kv =>
          val key = kv._1.substring(STREAMING_SINK_PREFIX.length)
          (key, kv._2)
        }
        StreamingSinkType.withName(source.toUpperCase(Locale.ROOT)) match {
          case CONSOLE =>
          case TEXT | PARQUET | ORC | JSON | CSV =>
            if (sinkOptions.get(STREAMING_SINK_PATH).isEmpty) {
              throw new SparkException("Sink type is file, must config path")
            }
          case KAFKA =>
            if (sinkOptions.get(STREAMING_SINK_BOOTSTRAP_SERVERS).isEmpty) {
              throw new SparkException("Sink type is kafka, must config bootstrap servers")
            }
            if (sinkOptions.get(STREAMING_SINK_TOPIC).isEmpty) {
              throw new SparkException("Sink type is kafka, must config kafka topic")
            }
          case _ =>
            throw new SparkException(
              "Sink type is invalid, " +
                s"select from ${StreamingSinkType.values}")
        }
        val ds = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
        val disabledSources = sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
        val sink = ds.newInstance() match {
          case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) =>
            w
          case _ =>
            val ds = DataSource(
              sparkSession,
              className = source,
              options = sinkOptions.toMap,
              partitionColumns = Nil)
            ds.createSink(InternalOutputModes.Append)
        }
        val outputMode = InternalOutputModes(
          extraOptions.getOrElse(STREAMING_OUTPUT_MODE, DEFAULT_STREAMING_OUTPUT_MODE))
        val duration =
          extraOptions.getOrElse(STREAMING_TRIGGER_DURATION, DEFAULT_STREAMING_TRIGGER_DURATION)
        val trigger =
          extraOptions.getOrElse(STREAMING_TRIGGER_TYPE, DEFAULT_STREAMING_TRIGGER_TYPE) match {
            case STREAMING_MICRO_BATCH_TRIGGER => Trigger.ProcessingTime(duration)
            case STREAMING_ONCE_TRIGGER => Trigger.Once()
            case STREAMING_CONTINUOUS_TRIGGER => Trigger.Continuous(duration)
          }
        val query = sparkSession.sessionState.streamingQueryManager.startQuery(
          extraOptions.get("queryName"),
          extraOptions.get(STREAMING_CHECKPOINT_LOCATION),
          df,
          sinkOptions.toMap,
          sink,
          outputMode,
          useTempCheckpointLocation = source == DEFAULT_STREAMING_SINK,
          recoverFromCheckpointLocation = true,
          trigger = trigger)
        query.awaitTermination()
    }
    // dummy
    Seq.empty
  }
}

case class ConstructedStreaming(child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
} 
Example 25
Source File: DFConverter.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import com.twosigma.flint.rdd.OrderedRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StructType


object DFConverter {

  def newDataFrame(df: DataFrame): DataFrame = {
    new DataFrame(df.sparkSession, df.logicalPlan, RowEncoder(df.schema))
  }

  def toDataFrame(rdd: OrderedRDD[Long, InternalRow], schema: StructType): DataFrame = {
    val spark = SparkSession.builder().getOrCreate()
    val internalRows = rdd.values
    spark.internalCreateDataFrame(internalRows, schema)
  }

  def toDataFrame(rdd: RDD[InternalRow], schema: StructType): DataFrame = {
    val spark = SparkSession.builder().getOrCreate()
    spark.internalCreateDataFrame(rdd, schema)
  }

} 
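DFConverter sits in the org.apache.spark.sql package so it can reach the package-private Dataset constructor and logicalPlan; callers only need the public object. A hedged usage sketch, assuming flint (and therefore DFConverter) is on the classpath:

import org.apache.spark.sql.{DFConverter, SparkSession}

object DFConverterUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dfconverter-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1L, "a"), (2L, "b")).toDF("time", "value")

    // newDataFrame rewraps the same logical plan in a fresh Dataset[Row]
    // whose RowEncoder is rebuilt from the DataFrame's current schema.
    val rebuilt = DFConverter.newDataFrame(df)
    assert(rebuilt.schema == df.schema)

    spark.stop()
  }
}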
Example 26
Source File: S2StreamQueryWriter.scala    From incubator-s2graph   with Apache License 2.0 5 votes vote down vote up
package org.apache.s2graph.spark.sql.streaming

import com.typesafe.config.ConfigFactory
import org.apache.s2graph.core.{GraphElement, JSONParser}
import org.apache.s2graph.s2jobs.S2GraphHelper
import org.apache.s2graph.spark.sql.streaming.S2SinkConfigs._
import org.apache.spark.TaskContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.types.StructType
import play.api.libs.json.{JsObject, Json}

import scala.collection.mutable.ListBuffer
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.util.Try

private[sql] class S2StreamQueryWriter(
    serializedConf: String,
    schema: StructType,
    commitProtocol: S2CommitProtocol
  ) extends Serializable with Logger {
  private val config = ConfigFactory.parseString(serializedConf)
  private val s2Graph = S2GraphHelper.getS2Graph(config)
  private val encoder: ExpressionEncoder[Row] = RowEncoder(schema).resolveAndBind()
  private val RESERVED_COLUMN = Set("timestamp", "from", "to", "label", "operation", "elem", "direction")


  def run(taskContext: TaskContext, iters: Iterator[InternalRow]): TaskCommit = {
    val taskId = s"stage-${taskContext.stageId()}, partition-${taskContext.partitionId()}, attempt-${taskContext.taskAttemptId()}"
    val partitionId = taskContext.partitionId()

    val groupedSize = getConfigString(config, S2_SINK_GROUPED_SIZE, DEFAULT_GROUPED_SIZE).toInt
    val waitTime = getConfigString(config, S2_SINK_WAIT_TIME, DEFAULT_WAIT_TIME_SECONDS).toInt

    commitProtocol.initTask()
    try {
      var list = new ListBuffer[(String, Int)]()
      val rst = iters.flatMap(rowToEdge).grouped(groupedSize).flatMap{ elements =>
        logger.debug(s"[$taskId][elements] ${elements.size} (${elements.map(e => e.toLogString).mkString(",\n")})")
        elements.groupBy(_.serviceName).foreach{ case (service, elems) =>
          list += ((service, elems.size))
        }

        val mutateF = s2Graph.mutateElements(elements, true)
        Await.result(mutateF, Duration(waitTime, "seconds"))
      }

      val (success, fail) = rst.toSeq.partition(r => r.isSuccess)
      val counter = list.groupBy(_._1).map{ case (service, t) =>
        val sum = t.toList.map(_._2).sum
        (service, sum)
      }
      logger.info(s"[$taskId] success : ${success.size}, fail : ${fail.size} ($counter)")


      commitProtocol.commitTask(TaskState(partitionId, success.size, fail.size, counter))

    } catch {
      case t: Throwable =>
        commitProtocol.abortTask(TaskState(partitionId))
        throw t
    }
  }

  private def rowToEdge(internalRow: InternalRow): Option[GraphElement] =
    S2GraphHelper.sparkSqlRowToGraphElement(s2Graph, encoder.fromRow(internalRow), schema, RESERVED_COLUMN)
}