org.apache.spark.TaskContext Scala Examples

Example 1
package org.apache.spark.sql.execution.datasources.text

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.apache.hadoop.util.ReflectionUtils

import org.apache.spark.TaskContext
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration

  /**
   * Returns the default file extension of the output compression codec when compressed
   * output is enabled for `context`.
   *
   * The codec class is looked up through `FileOutputFormat`, with `GzipCodec` passed as the
   * default, and is instantiated reflectively so its extension can be queried.
   * NOTE(review): the branch for uncompressed output is not visible in this excerpt.
   */
  def getCompressionExtension(context: TaskAttemptContext): String = {
    // Set the compression extension, similar to code in TextOutputFormat.getDefaultWorkFile
    if (FileOutputFormat.getCompressOutput(context)) {
      val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec])
      ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension
    } else {
Example 2
Source File: SnowflakeRDD.scala (from spark-snowflake, Apache License 2.0; 5 votes)


import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

class SnowflakeRDD(sc: SparkContext,
                   fileNames: List[String],
                   format: SupportedFormat,
                   downloadFile: String => InputStream,
                   expectedPartitionCount: Int)
    extends RDD[String](sc, Nil) {

  @transient private val MIN_FILES_PER_PARTITION = 2
  @transient private val MAX_FILES_PER_PARTITION = 10

  override def compute(split: Partition,
                       context: TaskContext): Iterator[String] = {
    val snowflakePartition = split.asInstanceOf[SnowflakePartition]

    val stringIterator = new SFRecordReader(format, snowflakePartition.index)

    snowflakePartition.fileNames.foreach(name => {
      s"""${SnowflakeResultSetRDD.WORKER_LOG_PREFIX}: Start reading
         | partition ID:${snowflakePartition.index}
         | totalFileCount=${snowflakePartition.fileNames.size}
         |""".stripMargin.filter(_ >= ' '))


  override protected def getPartitions: Array[Partition] = {
    var fileCountPerPartition =
        (fileNames.length + expectedPartitionCount / 2) / expectedPartitionCount
    fileCountPerPartition = Math.min(MAX_FILES_PER_PARTITION, fileCountPerPartition)
    val fileCount = fileNames.length
    val partitionCount = (fileCount + fileCountPerPartition - 1) / fileCountPerPartition"""${SnowflakeResultSetRDD.MASTER_LOG_PREFIX}: Total statistics:
         | fileCount=$fileCount filePerPartition=$fileCountPerPartition
         | actualPartitionCount=$partitionCount
         | expectedPartitionCount=$expectedPartitionCount
         |""".stripMargin.filter(_ >= ' '))

    if (fileNames.nonEmpty) {
        .map {
          case (names, index) => SnowflakePartition(names, id, index)
    } else {
      // If the result set is empty, put one empty partition to the array.
      Seq[SnowflakePartition]{SnowflakePartition(fileNames, 0, 0)}.toArray


/**
 * Partition descriptor carrying the file names assigned to the partition, an `rddId`,
 * and the partition `index`.
 */
private case class SnowflakePartition(fileNames: List[String],
                                      rddId: Int,
                                      index: Int)
    extends Partition {

  // The hash is derived from rddId and index only; fileNames does not contribute.
  override def hashCode(): Int = 31 * (31 + rddId) + index

  // Equality is delegated to the parent implementation instead of a field-wise comparison.
  // NOTE(review): presumably this picks up Partition's equals; confirm it stays consistent
  // with the hashCode above.
  override def equals(other: Any): Boolean = super.equals(other)
Example 3
Source File: SlidingRDD.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.mllib.rdd

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD

/**
 * Partition of a sliding-window RDD.
 *
 * @param idx    index of this partition (exposed through `index`)
 * @param prev   the parent partition whose elements are read first
 * @param tail   extra elements read after the parent partition's own elements
 * @param offset NOTE(review): presumably the number of leading elements to skip before the
 *               first window; its use is not visible in this excerpt
 */
class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int)
  extends Partition with Serializable {
  override val index: Int = idx

class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int)
  extends RDD[Array[T]](parent) {

  require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1),
    "Window size and step must be greater than 0, " +
      s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.")

  override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
    val part = split.asInstanceOf[SlidingRDDPartition[T]]
    (firstParent[T].iterator(part.prev, context) ++ part.tail)
      .sliding(windowSize, step)

  override def getPreferredLocations(split: Partition): Seq[String] =

  /**
   * Builds the sliding partitions from the parent's partitions.
   *
   * With exactly one parent partition a single partition with an empty tail and offset 0 is
   * produced. With more than one, the size and the first `windowSize - 1` elements of every
   * parent partition are gathered first, and each parent partition is then given the tail
   * it needs from its successors.
   * NOTE(review): the result of the `n == 0` branch and the final return value are not
   * visible in this excerpt.
   */
  override def getPartitions: Array[Partition] = {
    val parentPartitions = parent.partitions
    val n = parentPartitions.length
    if (n == 0) {
    } else if (n == 1) {
      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0))
    } else {
      // A window starting in one partition can need at most windowSize - 1 following elements.
      val w1 = windowSize - 1
      // Get partition sizes and first w1 elements.
      val (sizes, heads) = parent.mapPartitions { iter =>
        val w1Array = iter.take(w1).toArray
        // iter has already been advanced by take, so iter.length counts only the remainder.
        Iterator.single((w1Array.length + iter.length, w1Array))
      val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]]
      var i = 0
      // Number of elements in all parent partitions before partition i.
      var cumSize = 0
      var partitionIndex = 0
      while (i < n) {
        // Distance from the start of partition i to the next multiple of step.
        val mod = cumSize % step
        val offset = if (mod == 0) 0 else step - mod
        val size = sizes(i)
        // Only partitions that contain at least one window start are considered.
        if (offset < size) {
          val tail = mutable.ListBuffer.empty[T]
          // Keep appending to the current tail until it has w1 elements.
          var j = i + 1
          while (j < n && tail.length < w1) {
            tail ++= heads(j).take(w1 - tail.length)
            j += 1
          // Emit a partition only if at least one full window fits after skipping offset.
          if (sizes(i) + tail.length >= offset + windowSize) {
            partitions +=
              new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset)
            partitionIndex += 1
        cumSize += size
        i += 1

  // TODO: Override methods such as aggregate, which only requires one Spark job.
Example 4
Source File: CachedKafkaConsumer.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer}
import org.apache.kafka.common.TopicPartition

import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.Logging

  /**
   * Returns the consumer to use for the given topic partition, keyed in the cache by the
   * consumer group id (read from `kafkaParams`) and the topic partition.
   *
   * A retried task always receives a freshly constructed consumer; otherwise a consumer is
   * put into the cache if none exists for the key yet. The whole lookup runs under
   * `synchronized`.
   *
   * @param topic       Kafka topic name
   * @param partition   partition number within the topic
   * @param kafkaParams consumer configuration; must contain the group id
   */
  def getOrCreate(
      topic: String,
      partition: Int,
      kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized {
    val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
    val topicPartition = new TopicPartition(topic, partition)
    val key = CacheKey(groupId, topicPartition)

    // If this is reattempt at running the task, then invalidate cache and start with
    // a new consumer
    // TaskContext.attemptNumber is 0 for the first attempt, so the first retry already has
    // attemptNumber == 1; the comparison must therefore be >= 1, not > 1.
    // TaskContext.get is null outside of a running task, hence the Option wrapper.
    val isRetry = Option(TaskContext.get).exists(_.attemptNumber >= 1)
    if (isRetry) {
      new CachedKafkaConsumer(topicPartition, kafkaParams)
    } else {
      if (!cache.containsKey(key)) {
        cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams))
Example 5
Source File: CommitFailureTestSource.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.sources

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.TaskContext
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

/**
 * Test data source whose writers fail on purpose: `write` fails when
 * `SimpleTextRelation.failWriter` is set, and `close` always fails.
 */
class CommitFailureTestSource extends SimpleTextSource {
  /** Returns a factory whose writers are instrumented to fail as described on the class. */
  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory =
    new OutputWriterFactory {
      override def newInstance(
          stagingDir: String,
          fileNamePrefix: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context) {
          var failed = false
          // Record that the task-failure callback fired so tests can assert on it.
          TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
            failed = true
            SimpleTextRelation.callbackCalled = true

          override val path: String = new Path(stagingDir, fileNamePrefix).toString

          // Fails only when the test has switched on SimpleTextRelation.failWriter.
          override def write(row: Row): Unit = {
            if (SimpleTextRelation.failWriter) {
              sys.error("Intentional task writer failure for testing purpose.")


          // Always fails, simulating a failure at task-commit time.
          override def close(): Unit = {
            sys.error("Intentional task commitment failure for testing purpose.")

  override def shortName(): String = "commit-failure-test"
Example 6
Source File: MonotonicallyIncreasingID.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  /**
   * Resets the per-partition row counter and precomputes the partition id shifted left by
   * 33 bits, leaving the low 33 bits free for the counter.
   */
  override protected def initInternal(): Unit = {
    count = 0L
    partitionMask = TaskContext.getPartitionId().toLong << 33

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  /**
   * Returns the shifted partition id plus the counter value from before this call, and
   * advances the counter by one. `input` is not read.
   */
  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount

  /**
   * Generated-code counterpart of the interpreted evaluation: registers two mutable `long`
   * states (counter and shifted partition id) and emits code that yields their sum and
   * then increments the counter. The result is never null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
    // Same 33-bit shift as the interpreted path, expressed as Java source.
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
      s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"
Example 7
Source File: randomExpressions.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

  usage = "_FUNC_(a) - Returns a random column with i.i.d. gaussian random distribution.")
case class Randn(seed: Long) extends RDG {
  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  def this() = this(Utils.random.nextLong())

  def this(seed: Expression) = this(seed match {
    case IntegerLiteral(s) => s
    case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")

  /**
   * Emits code that draws one Gaussian sample per row. The generator is held as mutable
   * state and is seeded with `seed` plus the current partition id, so each partition gets a
   * different seed. The result is never null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val rngTerm = ctx.freshName("rng")
    val className = classOf[XORShiftRandom].getName
    ctx.addMutableState(className, rngTerm,
      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
Example 8
Source File: ShuffledHashJoinExec.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.execution.joins

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

case class ShuffledHashJoinExec(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    buildSide: BuildSide,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan)
  extends BinaryExecNode with HashJoin {

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))

  /**
   * Distribution required of the two children, in (left, right) order: each side must be
   * clustered on its own join keys.
   */
  override def requiredChildDistribution: Seq[Distribution] =
    List(ClusteredDistribution(leftKeys), ClusteredDistribution(rightKeys))

  /**
   * Builds the hashed relation for the build side from `iter`, using the current task's
   * memory manager, and records build time and estimated size in the operator metrics.
   * NOTE(review): the returned value is not visible in this excerpt.
   */
  private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
    val buildDataSize = longMetric("buildDataSize")
    val buildTime = longMetric("buildTime")
    val start = System.nanoTime()
    val context = TaskContext.get()
    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
    // Elapsed nanoseconds converted to milliseconds.
    buildTime += (System.nanoTime() - start) / 1000000
    buildDataSize += relation.estimatedSize
    // This relation is usually used until the end of task.
    context.addTaskCompletionListener(_ => relation.close())

  /**
   * Zips the partitions of the streamed and build plans; for each pair, builds the hashed
   * relation from the build side and joins the streamed rows against it, counting output
   * rows in the `numOutputRows` metric.
   */
  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
      val hashed = buildHashedRelation(buildIter)
      join(streamIter, hashed, numOutputRows)
Example 9
Source File: StateStoreRDD.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.execution.streaming.state

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

class StateStoreRDD[T: ClassTag, U: ClassTag](
    dataRDD: RDD[T],
    storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
    checkpointLocation: String,
    operatorId: Long,
    storeVersion: Long,
    keySchema: StructType,
    valueSchema: StructType,
    sessionState: SessionState,
    @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
  extends RDD[U](dataRDD) {

  private val storeConf = new StateStoreConf(sessionState.conf)

  // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
  private val confBroadcast = dataRDD.context.broadcast(
    new SerializableConfiguration(sessionState.newHadoopConf()))

  override protected def getPartitions: Array[Partition] = dataRDD.partitions

  override def getPreferredLocations(partition: Partition): Seq[String] = {
    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)

  /**
   * Computes one output partition: obtains the state store identified by the checkpoint
   * location, operator id and partition index at `storeVersion`, then hands it together with
   * the parent partition's iterator to `storeUpdateFunction`.
   *
   * @param partition the partition to compute
   * @param ctxt      the running task's context, forwarded to the parent RDD
   */
  override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
    // The store is assigned exactly once, so it is bound to a val rather than a
    // null-initialised var.
    val store: StateStore = StateStore.get(
      storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
    val inputIter = dataRDD.iterator(partition, ctxt)
    storeUpdateFunction(store, inputIter)
Example 10
Source File: ReferenceSort.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.execution

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

case class ReferenceSort(
    sortOrder: Seq[SortOrder],
    global: Boolean,
    child: SparkPlan)
  extends UnaryExecNode {

  /**
   * A global sort requires the child to be range-distributed on the sort order; a
   * per-partition sort places no requirement on the child.
   */
  override def requiredChildDistribution: Seq[Distribution] = {
    val required: Distribution =
      if (global) OrderedDistribution(sortOrder) else UnspecifiedDistribution
    required :: Nil
  }

  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
    child.execute().mapPartitions( { iterator =>
      val ordering = newOrdering(sortOrder, child.output)
      val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
        TaskContext.get(), ordering = Some(ordering))
      sorter.insertAll( => (r.copy(), null)))
      val baseIterator =
      val context = TaskContext.get()
      CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
    }, preservesPartitioning = true)

  // Reports the child's output attributes unchanged.
  override def output: Seq[Attribute] = child.output

  // Reports the requested sort order as this operator's output ordering.
  override def outputOrdering: Seq[SortOrder] = sortOrder

  // Reports the child's partitioning unchanged.
  override def outputPartitioning: Partitioning = child.outputPartitioning
Example 11
Source File: SparkHadoopMapRedUtil.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.mapred


import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext}
import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.internal.Logging

object SparkHadoopMapRedUtil extends Logging {
  /**
   * Decides whether the current task attempt may commit its output and reports the outcome.
   *
   * If the committer says no commit is needed, this only logs. Otherwise, when
   * `spark.hadoop.outputCommitCoordination.enabled` is true (the default), the driver's
   * output commit coordinator is asked for permission; a refusal raises
   * `CommitDeniedException`.
   * NOTE(review): the calls that perform the commit are not visible in this excerpt.
   *
   * @param committer     Hadoop output committer for the task
   * @param mrTaskContext Hadoop task attempt context
   * @param jobId         id passed to the commit coordinator and to the exception
   * @param splitId       id passed to the commit coordinator and to the exception
   */
  def commitTask(
      committer: MapReduceOutputCommitter,
      mrTaskContext: MapReduceTaskAttemptContext,
      jobId: Int,
      splitId: Int): Unit = {

    val mrTaskAttemptID = mrTaskContext.getTaskAttemptID

    // Called after we have decided to commit
    def performCommit(): Unit = {
      try {
        logInfo(s"$mrTaskAttemptID: Committed")
      } catch {
        // Log with the attempt id for context, then rethrow the same exception.
        case cause: IOException =>
          logError(s"Error committing the output of task: $mrTaskAttemptID", cause)
          throw cause

    // First, check whether the task's output has already been committed by some other attempt
    if (committer.needsTaskCommit(mrTaskContext)) {
      val shouldCoordinateWithDriver: Boolean = {
        val sparkConf = SparkEnv.get.conf
        // We only need to coordinate with the driver if there are concurrent task attempts.
        // Note that this could happen even when speculation is not enabled (e.g. see SPARK-8029).
        // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs.
        sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", defaultValue = true)

      if (shouldCoordinateWithDriver) {
        val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
        // The attempt number identifies this attempt to the coordinator.
        val taskAttemptNumber = TaskContext.get().attemptNumber()
        val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber)

        if (canCommit) {
        } else {
          val message =
            s"$mrTaskAttemptID: Not committed because the driver did not authorize commit"
          // We need to abort the task so that the driver can reschedule new attempts, if necessary
          throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber)
      } else {
        // Speculation is disabled or a user has chosen to manually bypass the commit coordination
    } else {
      // Some other attempt committed the output, so we do nothing and signal success
      logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID")
Example 12
Source File: taskListeners.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.util

import java.util.EventListener

import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi

class TaskCompletionListenerException(
    errorMessages: Seq[String],
    val previousError: Option[Throwable] = None)
  extends RuntimeException {

  override def getMessage: String = {
    if (errorMessages.size == 1) {
    } else { { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
    } + { e =>
      "\n\nPrevious exception in task: " + e.getMessage + "\n" +
        e.getStackTrace.mkString("\t", "\n\t", "")
Example 13
Source File: SubtractedRDD.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.rdd

import java.util.{HashMap => JHashMap}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.Dependency
import org.apache.spark.OneToOneDependency
import org.apache.spark.Partition
import org.apache.spark.Partitioner
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext

private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
    @transient var rdd1: RDD[_ <: Product2[K, V]],
    @transient var rdd2: RDD[_ <: Product2[K, W]],
    part: Partitioner)
  extends RDD[(K, V)](rdd1.context, Nil) {

  /**
   * Returns one dependency per input RDD, in (rdd1, rdd2) order. An input that is already
   * partitioned by `part` gets a one-to-one dependency; any other input gets a shuffle
   * dependency on `part`.
   */
  override def getDependencies: Seq[Dependency[_]] = {
    // Chooses the dependency kind for a single input RDD.
    def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]])
      : Dependency[_] = {
      if (rdd.partitioner == Some(part)) {
        logDebug("Adding one-to-one dependency with " + rdd)
        new OneToOneDependency(rdd)
      } else {
        logDebug("Adding shuffle dependency with " + rdd)
        new ShuffleDependency[T1, T2, Any](rdd, part)
    Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2))

  override def getPartitions: Array[Partition] = {
    val array = new Array[Partition](part.numPartitions)
    for (i <- 0 until array.length) {
      // Each CoGroupPartition will depend on rdd1 and rdd2
      array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2) { case (rdd, j) =>
        dependencies(j) match {
          case s: ShuffleDependency[_, _, _] =>
          case _ =>
            Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)))

  override val partitioner = Some(part)

  override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
    val partition = p.asInstanceOf[CoGroupPartition]
    val map = new JHashMap[K, ArrayBuffer[V]]
    def getSeq(k: K): ArrayBuffer[V] = {
      val seq = map.get(k)
      if (seq != null) {
      } else {
        val seq = new ArrayBuffer[V]()
        map.put(k, seq)
    def integrate(depNum: Int, op: Product2[K, V] => Unit): Unit = {
      dependencies(depNum) match {
        case oneToOneDependency: OneToOneDependency[_] =>
          val dependencyPartition = partition.narrowDeps(depNum).get.split
          oneToOneDependency.rdd.iterator(dependencyPartition, context)
            .asInstanceOf[Iterator[Product2[K, V]]].foreach(op)

        case shuffleDependency: ShuffleDependency[_, _, _] =>
          val iter = SparkEnv.get.shuffleManager
              shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context)

    // the first dep is rdd1; add all values to the map
    integrate(0, t => getSeq(t._1) += t._2)
    // the second dep is rdd2; remove all of its keys
    integrate(1, t => map.remove(t._1)) =>, _))).flatten

  override def clearDependencies() {
    rdd1 = null
    rdd2 = null

Example 14
Source File: ZippedWithIndexRDD.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils

class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index

  @transient private val startIndices: Array[Long] = {
    val n = prev.partitions.length
    if (n == 0) {
    } else if (n == 1) {
    } else {
        Utils.getIteratorSize _,
        0 until n - 1 // do not need to count the last partition
      ).scanLeft(0L)(_ + _)

  override def getPartitions: Array[Partition] = {
    firstParent[T] => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))

  override def getPreferredLocations(split: Partition): Seq[String] =

  /**
   * Pairs every element of the parent partition with a `Long` index, passing the
   * partition's `startIndex` to `Utils.getIteratorZipWithIndex` as the starting value.
   */
  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
    val parentIter = firstParent[T].iterator(split.prev, context)
    Utils.getIteratorZipWithIndex(parentIter, split.startIndex)
Example 15
Source File: UnionRDD.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.{ForkJoinTaskSupport, ThreadPoolTaskSupport}
import scala.concurrent.forkjoin.ForkJoinPool
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

/**
 * Partition of a union RDD that maps onto a single partition of one parent RDD.
 *
 * @param idx                     index of this partition within the union
 * @param rdd                     the parent RDD this partition comes from (transient)
 * @param parentRddIndex          position of the parent RDD among the union's inputs
 * @param parentRddPartitionIndex index of the partition within the parent RDD (transient)
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient private val rdd: RDD[T],
    val parentRddIndex: Int,
    @transient private val parentRddPartitionIndex: Int)
  extends Partition {

  // Mutable because it is refreshed in writeObject below.
  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  // Preferred locations are those of the underlying parent partition.
  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  // Custom serialization hook; I/O failures are handled by Utils.tryOrIOException.
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)

object UnionRDD {
  // Lazily created task support over a fork-join pool constructed with parallelism 8;
  // assigned to the parallel collection when partitions are listed in parallel.
  private[spark] lazy val partitionEvalTaskSupport =
    new ForkJoinTaskSupport(new ForkJoinPool(8))

class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  // visible for testing
  private[spark] val isPartitionListingParallel: Boolean =
    rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)

  override def getPartitions: Array[Partition] = {
    val parRDDs = if (isPartitionListingParallel) {
      val parArray = rdds.par
      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
    } else {
    val array = new Array[Partition](
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1

  /**
   * Builds one range dependency per parent RDD. Each parent's partitions, starting at its
   * partition 0, occupy a contiguous block of the union's partitions; `pos` tracks where the
   * next block starts.
   * NOTE(review): the returned value is not visible in this excerpt.
   */
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length)
      pos += rdd.partitions.length

  /**
   * Reads the partition straight from the parent RDD it belongs to, selected by the
   * partition's `parentRddIndex`.
   */
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)

  override def getPreferredLocations(s: Partition): Seq[String] =

  override def clearDependencies() {
    rdds = null
Example 16
Source File: PartitionwiseSampledRDD.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.random.RandomSampler
import org.apache.spark.util.Utils

class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index

private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
    prev: RDD[T],
    sampler: RandomSampler[T, U],
    preservesPartitioning: Boolean,
    @transient private val seed: Long = Utils.random.nextLong)
  extends RDD[U](prev) {

  @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None

  override def getPartitions: Array[Partition] = {
    val random = new Random(seed)
    firstParent[T] => new PartitionwiseSampledRDDPartition(x, random.nextLong()))

  override def getPreferredLocations(split: Partition): Seq[String] =

  /**
   * Samples the parent partition with a clone of the configured sampler, so the shared
   * `sampler` instance is not used directly.
   * NOTE(review): the partition's `seed` is not applied in the visible lines; presumably
   * the clone is seeded in a line missing from this excerpt.
   */
  override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
    val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
    val thisSampler = sampler.clone
    thisSampler.sample(firstParent[T].iterator(split.prev, context))
Example 17
Source File: PartitionerAwareUnionRDD.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils

class PartitionerAwareUnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]]
  ) extends RDD[T](sc, => new OneToOneDependency(x))) {
  require(rdds.flatMap(_.partitioner).toSet.size == 1,
    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))

  override val partitioner = rdds.head.partitioner

  override def getPartitions: Array[Partition] = {
    val numPartitions = partitioner.get.numPartitions
    (0 until numPartitions).map { index =>
      new PartitionerAwareUnionRDDPartition(rdds, index)

  // Get the location where most of the partitions of parent RDDs are located
  override def getPreferredLocations(s: Partition): Seq[String] = {
    logDebug("Finding preferred location for " + this + ", partition " + s.index)
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    val locations = {
      case (rdd, part) =>
        val parentLocations = currPrefLocs(rdd, part)
        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
    val location = if (locations.isEmpty) {
    } else {
      // Find the location that maximum number of parent partitions prefer
      Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
    logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents {
      case (rdd, p) => rdd.iterator(p, context)

  override def clearDependencies() {
    rdds = null

  // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
    rdd.context.getPreferredLocs(rdd, part.index).map(tl =>
Example 18
Source File: MemoryTestingUtils.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.memory

import java.util.Properties

import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}

/** Helpers for memory-related tests. */
object MemoryTestingUtils {
  /**
   * Creates a task context with stage, partition, task-attempt and attempt numbers all set
   * to 0, backed by a new task memory manager built on `env.memoryManager`.
   *
   * @param env supplies the memory manager and the metrics system
   */
  def fakeTaskContext(env: SparkEnv): TaskContext = {
    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
    new TaskContextImpl(
      stageId = 0,
      partitionId = 0,
      taskAttemptId = 0,
      attemptNumber = 0,
      _taskMemoryManager = taskMemoryManager,
      localProperties = new Properties,
      metricsSystem = env.metricsSystem)
Example 19
Source File: FakeTask.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.scheduler

import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
import org.apache.spark.executor.TaskMetrics

// Stub Task[Int]: prepTask does nothing, runTask always returns 0, and the preferred
// locations are exactly the `prefLocs` passed to the constructor. The attempt id handed
// to Task is fixed at 0.
// NOTE(review): the default value of `serializedTaskMetrics`, the end of the parameter
// list and the closing brace of the class appear to be missing from this listing —
// confirm against the upstream file.
class FakeTask(
    stageId: Int,
    partitionId: Int,
    prefLocs: Seq[TaskLocation] = Nil,
    serializedTaskMetrics: Array[Byte] =
  extends Task[Int](stageId, 0, partitionId, serializedTaskMetrics) {

  // No preparation work is needed for the stub.
  override def prepTask(): Unit = {}
  // The result is a constant; the context is ignored.
  override def runTask(context: TaskContext): Int = 0
  override def preferredLocations: Seq[TaskLocation] = prefLocs

/** Factory methods that assemble a [[TaskSet]] made of [[FakeTask]] instances. */
object FakeTask {

  /**
   * Creates a task set for stage 0, stage attempt 0.
   *
   * @param numTasks number of tasks to create
   * @param prefLocs either empty, or one preferred-location list per task
   */
  def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    createTaskSet(numTasks, stageAttemptId = 0, prefLocs: _*)
  }

  /** Creates a task set for stage 0 with the given stage attempt id. */
  def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    createTaskSet(numTasks, stageId = 0, stageAttemptId, prefLocs: _*)
  }

  /**
   * Creates `numTasks` fake tasks, task `i` using partition id `i`, and wraps them in a
   * task set with priority 0 and `null` as the last constructor argument.
   *
   * @throws IllegalArgumentException if `prefLocs` is non-empty and its size differs
   *                                  from `numTasks`
   */
  def createTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*):
  TaskSet = {
    if (prefLocs.nonEmpty && prefLocs.size != numTasks) {
      throw new IllegalArgumentException("Wrong number of task locations")
    }
    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
      // With no locations supplied, every task gets an empty preference list.
      new FakeTask(stageId, i, if (prefLocs.nonEmpty) prefLocs(i) else Nil)
    }
    new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
  }
}
Example 20
Source File: OutputCommitCoordinatorIntegrationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.{Seconds, Span}

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext}
import org.apache.spark.util.Utils

class OutputCommitCoordinatorIntegrationSuite
  extends SparkFunSuite
  with LocalSparkContext
  with Timeouts {

  override def beforeAll(): Unit = {
    val conf = new SparkConf()
      .set("spark.hadoop.outputCommitCoordination.enabled", "true")
    sc = new SparkContext("local[2, 4]", "test", conf)

  test("exception thrown in OutputCommitter.commitTask()") {
    // Regression test for SPARK-10381
    failAfter(Span(60, Seconds)) {
      val tempDir = Utils.createTempDir()
      try {
        sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out")
      } finally {

// FileOutputCommitter whose commitTask throws when the running task's attempt number,
// read from the current TaskContext, is below 1.
// NOTE(review): the rest of commitTask and the closing braces are not visible in this
// listing — confirm against the upstream file.
private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter {
  override def commitTask(context: TaskAttemptContext): Unit = {
    val ctx = TaskContext.get()
    if (ctx.attemptNumber < 1) {
      // NOTE(review): the exception type on the next line appears to have been lost from
      // this listing; restore it from the upstream source.
      throw new"Intentional exception")
Example 21
Source File: PartitionPruningRDDSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext}

class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext {

  test("Pruned Partitions inherit locality prefs correctly") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 1),
          new TestPartition(1, 1),
          new TestPartition(2, 1))

      def compute(split: Partition, context: TaskContext) = {
    val prunedRDD = PartitionPruningRDD.create(rdd, _ == 2)
    assert(prunedRDD.partitions.length == 1)
    val p = prunedRDD.partitions(0)
    assert(p.index == 0)
    assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)

  test("Pruned Partitions can be unioned ") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 4),
          new TestPartition(1, 5),
          new TestPartition(2, 6))

      def compute(split: Partition, context: TaskContext) = {
    val prunedRDD1 = PartitionPruningRDD.create(rdd, _ == 0)

    val prunedRDD2 = PartitionPruningRDD.create(rdd, _ == 2)

    val merged = prunedRDD1 ++ prunedRDD2
    assert(merged.count() == 2)
    val take = merged.take(2)
    assert(take.apply(0) == 4)
    assert(take.apply(1) == 6)

/**
 * Serializable partition carrying an integer payload.
 *
 * @param i     value reported as the partition index
 * @param value payload exposed through `testValue`
 */
class TestPartition(i: Int, value: Int) extends Partition with Serializable {
  def index: Int = i
  def testValue: Int = this.value
}
Example 22
Source File: ParameterSynchronizer.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up

import java.util.concurrent.{ConcurrentHashMap, CyclicBarrier}

import org.apache.spark.TaskContext

import scala.reflect._

  def reset(): Unit = {
    if (data.size != 0) {
      data.synchronized {
        if (data.size != 0) {
Example 23
Source File: DistributedSynchronizerSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up

import org.apache.spark.{SparkContext, TaskContext}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class DistributedSynchronizerSpec extends FlatSpec with Matchers with BeforeAndAfter {

  var sc: SparkContext = null

  before {
    val conf = Engine.createSparkConf().setAppName("test synchronizer").setMaster("local[4]")
      .set("spark.rpc.message.maxSize", "200")
    sc = new SparkContext(conf)

  "DistributedSynchronizer" should "work properly" in {
    val partition = 4
    val cores = 4
    val res = sc.parallelize((0 until partition), partition).mapPartitions(p => {
      Engine.setNodeAndCore(partition, cores)
      val partitionID = TaskContext.getPartitionId
      val sync = new BlockManagerParameterSynchronizer[Float](partitionID, partition)
      val tensor = Tensor[Float](10).fill(partitionID.toFloat + 1.0f)
      sync.init(s"testPara", 10, weights = null, grads = tensor)
      var res : Iterator[_] = null
      res = Iterator.single(sync.get(s"testPara"))
    res.length should be  (4)
    res(0).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(1).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(2).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(3).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))

  "DistributedSynchronizer with parameter size less than partition" should "work properly" in {
    val cores1 = Runtime.getRuntime().availableProcessors
    val partition = 4
    val cores = 4
    val res = sc.parallelize((0 until partition), partition).mapPartitions(p => {
      Engine.setNodeAndCore(partition, cores)
      val partitionID = TaskContext.getPartitionId
      val sync = new BlockManagerParameterSynchronizer[Float](partitionID, partition)
      val tensor = Tensor[Float](2).fill(partitionID.toFloat + 1.0f)
      sync.init(s"testPara", 2, weights = null, grads = tensor)
      var res : Iterator[_] = null
      res = Iterator.single(sync.get(s"testPara"))
    res.length should be  (4)
    res(0).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](2).fill(2.5f))
    res(1).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](2).fill(2.5f))
    res(2).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](2).fill(2.5f))
    res(3).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](2).fill(2.5f))

  "DistributedSynchronizer with parameter offset > 1" should "work properly" in {
    val partition = 4
    val cores = 4
    val res = sc.parallelize((0 until partition), partition).mapPartitions(p => {
      Engine.setNodeAndCore(partition, cores)
      val partitionID = TaskContext.getPartitionId
      val sync = new BlockManagerParameterSynchronizer[Float](partitionID, partition)
      val tensor = Tensor[Float](20)
      val parameter = tensor.narrow(1, 10, 10).fill(partitionID.toFloat + 1.0f)
      sync.init(s"testPara", 10, weights = null, grads = parameter)
      var res : Iterator[_] = null
      res = Iterator.single(sync.get(s"testPara"))
    res.length should be  (4)
    res(0).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(1).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(2).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(3).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))

  after {
Example 24
Source File: SparkUtil.scala    From ArchiveSpark   with MIT License 5 votes vote down vote up
package org.archive.archivespark.sparkling.util

import org.apache.spark.TaskContext

object SparkUtil {
  private var cleanupObjects = collection.mutable.Map.empty[Any, Long]
  private var cleanups = collection.mutable.Map.empty[Long, collection.mutable.Map[Any, () => Unit]]

  def cleanupTask(owner: Any, cleanup: () => Unit): Unit = {
    val task = TaskContext.get
    if (task != null) {
      cleanupObjects.getOrElseUpdate(owner, {
        val attemptId = task.taskAttemptId
        val taskCleanups = cleanups.getOrElseUpdate(attemptId, {
          val taskCleanups = collection.mutable.Map.empty[Any, () => Unit]
          task.addTaskCompletionListener(ctx => {
            for ((o, c) <- taskCleanups) {
        taskCleanups.update(owner, cleanup)

  /**
   * Unregisters the cleanup associated with `owner`.
   *
   * The owner is always removed from `cleanupObjects`. Its entry in the per-task cleanup
   * map is removed only if that map is still registered in `cleanups`. A missing map is
   * ignored instead of raising `NoSuchElementException`.
   *
   * @param owner key under which the cleanup was registered
   */
  def removeTaskCleanup(owner: Any): Unit =
    for {
      attemptId <- cleanupObjects.remove(owner)
      taskCleanups <- cleanups.get(attemptId)
    } taskCleanups.remove(owner)
Example 25
Source File: S2StreamQueryWriter.scala    From incubator-s2graph   with Apache License 2.0 5 votes vote down vote up
package org.apache.s2graph.spark.sql.streaming

import com.typesafe.config.ConfigFactory
import org.apache.s2graph.core.{GraphElement, JSONParser}
import org.apache.s2graph.s2jobs.S2GraphHelper
import org.apache.s2graph.spark.sql.streaming.S2SinkConfigs._
import org.apache.spark.TaskContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.types.StructType
import play.api.libs.json.{JsObject, Json}

import scala.collection.mutable.ListBuffer
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.util.Try

private [sql] class S2StreamQueryWriter(
                                         schema: StructType ,
                                         commitProtocol: S2CommitProtocol
                                       ) extends Serializable with Logger {
  private val config = ConfigFactory.parseString(serializedConf)
  private val s2Graph = S2GraphHelper.getS2Graph(config)
  private val encoder: ExpressionEncoder[Row] = RowEncoder(schema).resolveAndBind()
  private val RESERVED_COLUMN = Set("timestamp", "from", "to", "label", "operation", "elem", "direction")

  def run(taskContext: TaskContext, iters: Iterator[InternalRow]): TaskCommit = {
    val taskId = s"stage-${taskContext.stageId()}, partition-${taskContext.partitionId()}, attempt-${taskContext.taskAttemptId()}"
    val partitionId= taskContext.partitionId()

    val groupedSize = getConfigString(config, S2_SINK_GROUPED_SIZE, DEFAULT_GROUPED_SIZE).toInt
    val waitTime = getConfigString(config, S2_SINK_WAIT_TIME, DEFAULT_WAIT_TIME_SECONDS).toInt

    try {
      var list = new ListBuffer[(String, Int)]()
      val rst = iters.flatMap(rowToEdge).grouped(groupedSize).flatMap{ elements =>
        logger.debug(s"[$taskId][elements] ${elements.size} (${ => e.toLogString).mkString(",\n")})")
        elements.groupBy(_.serviceName).foreach{ case (service, elems) =>
          list += ((service, elems.size))

        val mutateF = s2Graph.mutateElements(elements, true)
        Await.result(mutateF, Duration(waitTime, "seconds"))

      val (success, fail) = rst.toSeq.partition(r => r.isSuccess)
      val counter = list.groupBy(_._1).map{ case (service, t) =>
        val sum =
        (service, sum)
      }"[$taskId] success : ${success.size}, fail : ${fail.size} ($counter)")

      commitProtocol.commitTask(TaskState(partitionId, success.size, fail.size, counter))

    } catch {
      case t: Throwable =>
        throw t

  // Decodes the internal row with `encoder`, then asks S2GraphHelper to turn the decoded
  // row into a graph element using this writer's graph, schema and reserved column names.
  private def rowToEdge(internalRow:InternalRow): Option[GraphElement] = {
    val decodedRow = encoder.fromRow(internalRow)
    S2GraphHelper.sparkSqlRowToGraphElement(s2Graph, decodedRow, schema, RESERVED_COLUMN)
  }
Example 26
Source File: S2SparkSqlStreamingSink.scala    From incubator-s2graph   with Apache License 2.0 5 votes vote down vote up
package org.apache.s2graph.spark.sql.streaming

import java.util.UUID

import com.typesafe.config.{Config, ConfigRenderOptions}
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.streaming.{MetadataLog, Sink}
import org.apache.spark.sql.{DataFrame, SparkSession}

class S2SparkSqlStreamingSink(
                               sparkSession: SparkSession,
                             ) extends Sink with Logger {
  import S2SinkConfigs._

  private val APP_NAME = "s2graph"

  private val writeLog: MetadataLog[Array[S2SinkStatus]] = {
    val logPath = getCommitLogPath(config)"MetaDataLogPath: $logPath")

    new S2SinkMetadataLog(sparkSession, config, logPath)

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    logger.debug(s"addBatch : $batchId, getLatest : ${writeLog.getLatest()}")

    if (batchId <= writeLog.getLatest().map(_._1).getOrElse(-1L)) {"Skipping already committed batch [$batchId]")
    } else {
      val queryName = getConfigStringOpt(config, "queryname").getOrElse(UUID.randomUUID().toString)
      val commitProtocol = new S2CommitProtocol(writeLog)
      val jobState = JobState(queryName, batchId)
      val serializedConfig = config.root().render(ConfigRenderOptions.concise())
      val queryExecution = data.queryExecution
      val schema = data.schema

      SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
        try {
          val taskCommits = sparkSession.sparkContext.runJob(queryExecution.toRdd,
            (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
              new S2StreamQueryWriter(serializedConfig, schema, commitProtocol).run(taskContext, iter)
          commitProtocol.commitJob(jobState, taskCommits)
        } catch {
          case t: Throwable =>
            throw t;


  /**
   * Resolves the path of the commit log from the sink configuration.
   *
   * An explicit log path (`S2_SINK_LOG_PATH`) wins. Otherwise the path is derived from
   * the checkpoint location (`S2_SINK_CHECKPOINT_LOCATION`) as
   * `<checkpoint>/sinks/<APP_NAME>`.
   *
   * @param config sink configuration to read from
   * @return the commit log path
   * @throws IllegalArgumentException if neither setting is present
   */
  private def getCommitLogPath(config:Config): String = {
    val logPathOpt = getConfigStringOpt(config, S2_SINK_LOG_PATH)
    val userCheckpointLocationOpt = getConfigStringOpt(config, S2_SINK_CHECKPOINT_LOCATION)

    (logPathOpt, userCheckpointLocationOpt) match {
      case (Some(logPath), _) => logPath
      case (None, Some(userCheckpoint)) => s"$userCheckpoint/sinks/$APP_NAME"
      case _ => throw new IllegalArgumentException(s"failed to get commit log path")
    }
  }

  // Fixed display name for this sink.
  override def toString(): String = "S2GraphSink"
Example 27
Source File: RiakWriterTaskCompletionListener.scala    From spark-riak-connector   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.riak

import org.apache.spark.TaskContext
import org.apache.spark.executor.{DataWriteMethod, OutputMetrics}
import org.apache.spark.util.TaskCompletionListener

// Task completion listener that installs a new OutputMetrics object (write method
// DataWriteMethod.Hadoop) as the output metrics of the completed task.
// NOTE(review): `recordsWritten` is not used in the lines visible here, and the closing
// braces are missing — the listing has probably lost lines; confirm against upstream.
class RiakWriterTaskCompletionListener(recordsWritten: Long) extends TaskCompletionListener{

  override def onTaskCompletion(context: TaskContext): Unit = {
    val metrics = OutputMetrics(DataWriteMethod.Hadoop)
    // Replaces whatever output metrics the task context held before.
    context.taskMetrics().outputMetrics = Some(metrics)


/** Companion providing a factory for [[RiakWriterTaskCompletionListener]]. */
object RiakWriterTaskCompletionListener {

  /**
   * Creates a listener for the given record count.
   *
   * @param recordsWritten value passed to the listener's constructor
   */
  def apply(recordsWritten: Long): RiakWriterTaskCompletionListener =
    new RiakWriterTaskCompletionListener(recordsWritten)
}
Example 28
Source File: RiakTSRDDTest.scala    From spark-riak-connector   with Apache License 2.0 5 votes vote down vote up
package com.basho.riak.spark.rdd.timeseries

import com.basho.riak.client.core.query.timeseries.ColumnDescription.ColumnType
import com.basho.riak.client.core.query.timeseries.{Cell, ColumnDescription}
import com.basho.riak.spark.query.{QueryTS, TSQueryData}
import com.basho.riak.spark.rdd.{RegressionTests, RiakTSRDD}
import com.basho.riak.spark.rdd.connector.RiakConnector
import com.basho.riak.spark.rdd.partitioner.RiakTSPartition
import org.apache.spark.{SparkContext, TaskContext}
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.powermock.modules.junit4.PowerMockRunner
import org.powermock.api.mockito.PowerMockito
import com.basho.riak.client.core.query.timeseries.{Row => RiakRow}
import org.apache.spark.sql.{Row => SparkRow}
import org.junit.experimental.categories.Category
import org.mockito.Mock
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.powermock.core.classloader.annotations.PrepareForTest

import scala.collection.JavaConversions._

class RiakTSRDDTest {

  protected val rc: RiakConnector = null

  protected val sc: SparkContext = null

  protected val tc: TaskContext = null

  def readAll(): Unit = {
    val rdd = new RiakTSRDD[SparkRow](sc, rc, "test")

    val neTsQD = TSQueryData("non-empty", None)
    val nonEmptyResponse = (Seq(new ColumnDescription("Col1", ColumnType.VARCHAR), new ColumnDescription("Col2", ColumnType.SINT64)),
      Seq(new RiakRow(List(new Cell("string-value"), new Cell(42)))))

    PowerMockito.whenNew(classOf[QueryTS]).withAnyArguments().thenAnswer( new Answer[QueryTS]{
      override def answer(invocation: InvocationOnMock): QueryTS = {
        val args = invocation.getArguments
        val rc: RiakConnector = args(0).asInstanceOf[RiakConnector]
        val qd: Seq[TSQueryData] = args(1).asInstanceOf[Seq[TSQueryData]]

        val q = spy(QueryTS(rc, qd))

        // By default returns an empty result
        doReturn(Seq() -> Seq()).when(q).nextChunk(any[TSQueryData])

        // Return 1 row for non-empty result


    // ----------  Perform test
    val iterator = rdd.compute( RiakTSPartition(0, Nil, List( null, null, null, neTsQD, neTsQD, null, neTsQD, null)), tc)

    // ----------  verify results
    val seq: Seq[SparkRow] = iterator.toIndexedSeq

    assertEquals(3, seq.size)

    seq.foreach(r => {
      assertEquals(2, r.size)
      assertEquals("string-value", r.get(0))
      assertEquals(42l, r.get(1))
Example 29
Source File: ArrowConverters.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.arrow

import java.nio.channels.Channels

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._
import com.twosigma.flint.util.Utils
import org.apache.arrow.vector.ipc.{ ArrowFileReader, ArrowFileWriter }
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch

/** An [[Iterator]] that is also [[AutoCloseable]]. */
trait ClosableIterator[T] extends Iterator[T] with AutoCloseable

class ConcatClosableIterator[T](iters: Iterator[ClosableIterator[T]])
  extends ClosableIterator[T] {
  var curIter: ClosableIterator[T] = _

  private def advance(): Unit = {
    require(curIter == null || !curIter.hasNext, "Should not advance if curIter is not empty")
    require(iters.hasNext, "Should not advance if iters doesn't have next")
    curIter =

  // Closes the iterator currently held in `curIter`; does nothing while it is still null.
  private def closeCurrent(): Unit = Option(curIter).foreach(_.close())

  // Delegates to closeCurrent(), i.e. closes the iterator currently held in `curIter`.
  override def close(): Unit = closeCurrent()

  override def hasNext: Boolean = {
    if (curIter == null || !curIter.hasNext) {
      if (iters.hasNext) {
      } else {
    } else {

  override def next(): T =

  def byteArrayToBatch(
    batchBytes: Array[Byte],
    allocator: BufferAllocator
  ): ArrowRecordBatch = {
    val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
    val reader = new ArrowFileReader(in, allocator)

    // Read a batch from a byte stream, ensure the reader is closed
    Utils.tryWithSafeFinally {
      val root = reader.getVectorSchemaRoot
      // throws IOException
      val unloader = new VectorUnloader(root)
      reader.loadNextBatch() // throws IOException
    } {
Example 30
Source File: PartitionsIterator.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.rdd

import grizzled.slf4j.Logger

import org.apache.spark.rdd.RDD
import org.apache.spark.{ Partition, TaskContext }

protected[flint] object PartitionsIterator {
  val logger = Logger(PartitionsIterator.getClass)

  /**
   * Factory that forwards all arguments unchanged to the [[PartitionsIterator]] constructor.
   *
   * @param rdd                         RDD passed to the iterator
   * @param partitions                  partitions passed to the iterator
   * @param context                     task context passed to the iterator
   * @param preservesPartitionsOrdering forwarded flag, `false` by default
   */
  def apply[T](
    rdd: RDD[T],
    partitions: Seq[Partition],
    context: TaskContext,
    preservesPartitionsOrdering: Boolean = false // FIXME: This is a band-aid which should be fixed.
  ): PartitionsIterator[T] = new PartitionsIterator(rdd, partitions, context, preservesPartitionsOrdering)

  def headPartitionIndex: Int = curPart.index
Example 31
Source File: SummarizeByKeyIterator.scala    From flint   with Apache License 2.0 5 votes vote down vote up

import java.util

import com.twosigma.flint.rdd.function.summarize.summarizer.Summarizer
import com.twosigma.flint.rdd.function.window.SummarizeWindows
import org.apache.spark.TaskContext

import scala.reflect.ClassTag
import scala.collection.JavaConverters._

private[rdd] class SummarizeByKeyIterator[K, V, SK, U, V2](
  iter: Iterator[(K, V)],
  skFn: V => SK,
  summarizer: Summarizer[V, U, V2]
)(implicit tag: ClassTag[V], ord: Ordering[K])
  extends Iterator[(K, (SK, V2))]
  with AutoCloseable {
  private[this] val bufferedIter = iter.buffered

  private[this] var currentKey: K = _

  // We use a mutable linked hash map in order to preserve the secondary key ordering.
  private[this] val intermediates: util.LinkedHashMap[SK, U] =
    new util.LinkedHashMap()

  // There is more output while intermediates are still buffered; once they are
  // exhausted, the answer depends on whether the source iterator has rows left.
  override def hasNext: Boolean =
    if (intermediates.isEmpty) bufferedIter.hasNext else true

  // Update intermediates with next key if bufferedIter.hasNext.
  private def nextKey(): Unit = if (bufferedIter.hasNext) {
    currentKey = bufferedIter.head._1
    // Iterates through all rows from the given iterator until seeing a different key.
    do {
      val v =
      val sk = skFn(v)
      val intermediate = SummarizeWindows.lazyGetOrDefault(intermediates, sk,
      intermediates.put(sk, summarizer.add(intermediate, v))
    } while (bufferedIter.hasNext && ord.equiv(bufferedIter.head._1, currentKey))

  override def next(): (K, (SK, V2)) = {
    if (intermediates.isEmpty) {
    if (hasNext) {
      val entry = intermediates.entrySet().iterator().next()
      val sk = entry.getKey
      val intermediate = entry.getValue
      (currentKey, (sk, summarizer.render(intermediate)))
    } else {

  override def close(): Unit = intermediates.asScala.toMap.values.foreach {
    u =>
Example 32
Source File: ParallelCollectionRDD.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.rdd

import org.apache.spark.rdd.RDD
import org.apache.spark.{ Partition, SparkContext, TaskContext }

import scala.reflect.ClassTag

/**
 * A [[Partition]] that carries its own data.
 *
 * @param index  partition index
 * @param values elements held by this partition
 */
case class ParallelCollectionRDDPartition[T: ClassTag](
  override val index: Int,
  values: Seq[T]
) extends Partition

class ParallelCollectionRDD[T: ClassTag](
  sc: SparkContext,
  @transient data: Seq[Seq[T]]
) extends RDD[T](sc, Nil) {
  override def compute(split: Partition, context: TaskContext): Iterator[T] =

  override protected def getPartitions: Array[Partition] = {
      case (d, index) =>
        ParallelCollectionRDDPartition(index, d)
Example 33
Source File: TiHandleRDD.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.tispark

import com.pingcap.tikv.meta.TiDAGRequest
import com.pingcap.tikv.util.RangeSplitter
import com.pingcap.tikv.{TiConfiguration, TiSession}
import com.pingcap.tispark.utils.TiUtil
import com.pingcap.tispark.{TiPartition, TiTableReference}
import gnu.trove.list.array.TLongArrayList
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.{Partition, TaskContext, TaskKilledException}

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._

class TiHandleRDD(
    override val dagRequest: TiDAGRequest,
    override val physicalId: Long,
    val output: Seq[Attribute],
    override val tiConf: TiConfiguration,
    override val tableRef: TiTableReference,
    @transient private val session: TiSession,
    @transient private val sparkSession: SparkSession)
    extends TiRDD(dagRequest, physicalId, tiConf, tableRef, session, sparkSession) {

  private val outputTypes =
  private val converters =

  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] =
    new Iterator[InternalRow] {

      private val tiPartition = split.asInstanceOf[TiPartition]
      private val session = TiSession.getInstance(tiConf)
      private val snapshot = session.createSnapshot(dagRequest.getStartTs)
      private[this] val tasks = tiPartition.tasks

      private val handleIterator = snapshot.indexHandleRead(dagRequest, tasks)
      private val regionManager = session.getRegionManager
      private lazy val handleList = {
        val lst = new TLongArrayList()
        handleIterator.asScala.foreach {
          // Kill the task in case it has been marked as killed. This logic is from
          // InterruptedIterator, but we inline it here instead of wrapping the iterator in order
          // to avoid performance overhead.
          if (context.isInterrupted()) {
            throw new TaskKilledException
      // Fetch all handles and group by region id
      private val regionHandleMap = RangeSplitter
        .groupByAndSortHandlesByRegionId(physicalId, handleList)
        .map(x => (x._1.first.getId, x._2))

      private val iterator = regionHandleMap.iterator

      override def hasNext: Boolean = {
        // Kill the task in case it has been marked as killed.
        if (context.isInterrupted()) {
          throw new TaskKilledException

      override def next(): InternalRow = {
        val next =
        val regionId = next._1
        val handleList = next._2

        // Returns RegionId:[handle1, handle2, handle3...] K-V pair
        val sparkRow = Row.apply(regionId, handleList.toArray())
        TiUtil.rowToInternalRow(sparkRow, outputTypes, converters)
Example 34
Source File: TiRowRDD.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.tispark

import com.pingcap.tikv._
import com.pingcap.tikv.columnar.TiColumnarBatchHelper
import com.pingcap.tikv.meta.TiDAGRequest
import com.pingcap.tispark.listener.CacheInvalidateListener
import com.pingcap.tispark.{TiPartition, TiTableReference}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.{Partition, TaskContext, TaskKilledException}
import org.slf4j.Logger

import scala.collection.JavaConversions._

class TiRowRDD(
    override val dagRequest: TiDAGRequest,
    override val physicalId: Long,
    val chunkBatchSize: Int,
    override val tiConf: TiConfiguration,
    val output: Seq[Attribute],
    override val tableRef: TiTableReference,
    @transient private val session: TiSession,
    @transient private val sparkSession: SparkSession)
    extends TiRDD(dagRequest, physicalId, tiConf, tableRef, session, sparkSession) {

  protected val logger: Logger = log

  // cache invalidation call back function
  // used for driver to update PD cache
  private val callBackFunc = CacheInvalidateListener.getInstance()

  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] =
    new Iterator[ColumnarBatch] {

      private val tiPartition = split.asInstanceOf[TiPartition]
      private val session = TiSession.getInstance(tiConf)
      private val snapshot = session.createSnapshot(dagRequest.getStartTs)
      private[this] val tasks = tiPartition.tasks

      private val iterator =
        snapshot.tableReadChunk(dagRequest, tasks, chunkBatchSize)

      override def hasNext: Boolean = {
        // Kill the task in case it has been marked as killed. This logic is from
        // Interrupted Iterator, but we inline it here instead of wrapping the iterator in order
        // to avoid performance overhead.
        if (context.isInterrupted()) {
          throw new TaskKilledException

      override def next(): ColumnarBatch = {

Example 35
Source File: RedisSourceRdd.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up

import com.redislabs.provider.redis.RedisConfig
import com.redislabs.provider.redis.util.ConnectionUtils.withConnection
import org.apache.spark.rdd.RDD
import org.apache.spark.{Partition, SparkContext, TaskContext}

class RedisSourceRdd(sc: SparkContext, redisConfig: RedisConfig,
                     offsetRanges: Seq[RedisSourceOffsetRange], autoAck: Boolean = true)
  extends RDD[StreamEntry](sc, Nil) {

  override def compute(split: Partition, context: TaskContext): Iterator[StreamEntry] = {
    val partition = split.asInstanceOf[RedisSourceRddPartition]
    val offsetRange = partition.offsetRange
    val streamReader = new RedisStreamReader(redisConfig)

  override protected def getPartitions: Array[Partition] = { { case (e, i) => RedisSourceRddPartition(i, e) }

// Partition of RedisSourceRdd: pairs the partition index with the offset range it covers.
case class RedisSourceRddPartition(index: Int, offsetRange: RedisSourceOffsetRange)
  extends Partition 
Example 36
Source File: EdgeRDDImpl.scala    From graphx-algorithm   with GNU General Public License v2.0 5 votes vote down vote up
package org.apache.spark.graphx.impl

import scala.reflect.{classTag, ClassTag}

import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext}
import org.apache.spark.rdd.RDD

import org.apache.spark.graphx._

class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] (
    @transient override val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])],
    val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY)
  extends EdgeRDD[ED](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) {

  override def setName(_name: String): this.type = {
    if ( != null) {
      partitionsRDD.setName( + ", " + _name)
    } else {

  override def count(): Long = { + _)

  override def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDDImpl[ED2, VD] =
    mapEdgePartitions((pid, part) =>

  // Builds a new EdgeRDDImpl by applying each edge partition's `reverse`.
  override def reverse: EdgeRDDImpl[ED, VD] =
    mapEdgePartitions { (_, edgePartition) => edgePartition.reverse }

  /**
   * Builds a new EdgeRDDImpl by applying `filter(epred, vpred)` to every edge partition.
   *
   * @param epred predicate over edge triplets, forwarded to each partition
   * @param vpred predicate over (vertex id, vertex attribute), forwarded to each partition
   */
  def filter(
      epred: EdgeTriplet[VD, ED] => Boolean,
      vpred: (VertexId, VD) => Boolean): EdgeRDDImpl[ED, VD] =
    mapEdgePartitions((_, part) => part.filter(epred, vpred))

  override def innerJoin[ED2: ClassTag, ED3: ClassTag]
      (other: EdgeRDD[ED2])
      (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDDImpl[ED3, VD] = {
    val ed2Tag = classTag[ED2]
    val ed3Tag = classTag[ED3]
    this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) {
      (thisIter, otherIter) =>
        val (pid, thisEPart) =
        val (_, otherEPart) =
        Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag)))

  def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag](
      f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDDImpl[ED2, VD2] = {
    this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter =>
      if (iter.hasNext) {
        val (pid, ep) =
        Iterator(Tuple2(pid, f(pid, ep)))
      } else {
    }, preservesPartitioning = true))

  /**
   * Wraps the given partitions RDD in a new EdgeRDDImpl that keeps this instance's
   * target storage level.
   *
   * @param partitionsRDD RDD of (partition id, edge partition) pairs to wrap
   */
  private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag](
      partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDDImpl[ED2, VD2] =
    new EdgeRDDImpl(partitionsRDD, this.targetStorageLevel)

  /**
   * Returns a new EdgeRDDImpl over the same partitions RDD with a different target
   * storage level.
   *
   * @param targetStorageLevel storage level recorded on the new instance
   */
  override private[graphx] def withTargetStorageLevel(
      targetStorageLevel: StorageLevel): EdgeRDDImpl[ED, VD] =
    new EdgeRDDImpl(this.partitionsRDD, targetStorageLevel)

Example 37
Source File: MultiZippedPartitionRDD.scala    From spark-vlbfgs   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{Partition, SparkContext, TaskContext}

import scala.reflect.ClassTag

/**
 * Zips the partitions of an arbitrary number of RDDs of the same element type and feeds the
 * resulting list of iterators to `f`.
 *
 * @param sc                    the Spark context
 * @param f                     combines one iterator per parent RDD into the output iterator
 * @param rddList               the parent RDDs, in the order their iterators are passed to `f`
 * @param preservesPartitioning forwarded to `ZippedPartitionsBaseRDD`
 */
private[spark] class MultiZippedPartitionsRDD[A: ClassTag, V: ClassTag](
    sc: SparkContext,
    var f: (List[Iterator[A]]) => Iterator[V],
    var rddList: List[RDD[A]],
    preservesPartitioning: Boolean = false)
  extends ZippedPartitionsBaseRDD[V](sc, rddList, preservesPartitioning) {

  override def compute(s: Partition, context: TaskContext): Iterator[V] = {
    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
    // The i-th parent RDD reads the i-th parent partition of the zipped split.
    val iterList = rddList.zipWithIndex.map { case (rdd, index) =>
      rdd.iterator(partitions(index), context)
    }
    f(iterList)
  }

  override def clearDependencies(): Unit = {
    // Let the base RDD drop its dependency list before releasing our own references.
    super.clearDependencies()
    rddList = null
    f = null
  }
}
Example 38
Source File: MapJoinPartitionsRDDV2.scala    From spark-vlbfgs   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import org.apache.spark.serializer.Serializer
import org.apache.spark.{TaskContext, _}
import org.apache.spark.util.Utils

import scala.reflect.ClassTag

/**
 * Partition of a `MapJoinPartitionsRDDV2`: one partition of `rdd1` together with the
 * partitions of `rdd2` selected by `s2IdxArr`.
 *
 * @param idx      index of this partition and of the `rdd1` partition it wraps
 * @param rdd1     first parent RDD (transient, not serialized)
 * @param rdd2     second parent RDD (transient, not serialized)
 * @param s2IdxArr indices of the `rdd2` partitions associated with this partition
 */
class MapJoinPartitionsPartitionV2(
    idx: Int,
    @transient private val rdd1: RDD[_],
    @transient private val rdd2: RDD[_],
    s2IdxArr: Array[Int]) extends Partition {

  var s1 = rdd1.partitions(idx)
  var s2Arr = s2IdxArr.map(s2Idx => rdd2.partitions(s2Idx))
  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Re-read the parent partitions at serialization time so the task ships current references.
    s1 = rdd1.partitions(idx)
    s2Arr = s2IdxArr.map(s2Idx => rdd2.partitions(s2Idx))
    oos.defaultWriteObject()
  }
}

/**
 * Joins each partition of `rdd1` with a chosen set of `rdd2` partitions.
 *
 * `rdd1` is a one-to-one dependency; `rdd2` is tagged with its partition id and pulled
 * through a shuffle dependency keyed by that id, so a task can fetch exactly the `rdd2`
 * partitions that `idxF` selects for it.
 *
 * @param sc                    the Spark context
 * @param idxF                  maps an `rdd1` partition index to the `rdd2` partition indices to join
 * @param f                     combines the `rdd1` partition index, its iterator and the
 *                              (index, iterator) pairs of the selected `rdd2` partitions
 * @param rdd1                  first parent RDD
 * @param rdd2                  second parent RDD
 * @param preservesPartitioning NOTE(review): not referenced in the visible code — confirm whether
 *                              a `partitioner` override was lost
 */
class MapJoinPartitionsRDDV2[A: ClassTag, B: ClassTag, V: ClassTag](
    sc: SparkContext,
    var idxF: (Int) => Array[Int],
    var f: (Int, Iterator[A], Array[(Int, Iterator[B])]) => Iterator[V],
    var rdd1: RDD[A],
    var rdd2: RDD[B],
    preservesPartitioning: Boolean = false)
  extends RDD[V](sc, Nil) {

  // Key every rdd2 element by the partition it came from.
  var rdd2WithPid = rdd2.mapPartitionsWithIndex((pid, iter) => iter.map(x => (pid, x)))

  private val serializer: Serializer = SparkEnv.get.serializer

  override def getPartitions: Array[Partition] = {
    val array = new Array[Partition](rdd1.partitions.length)
    for (s1 <- rdd1.partitions) {
      val idx = s1.index
      array(idx) = new MapJoinPartitionsPartitionV2(idx, rdd1, rdd2, idxF(idx))
    }
    array
  }

  override def getDependencies: Seq[Dependency[_]] = List(
    new OneToOneDependency(rdd1),
    new ShuffleDependency[Int, B, B](
      rdd2WithPid.asInstanceOf[RDD[_ <: Product2[Int, B]]],
      new IdentityPartitioner(rdd2WithPid.getNumPartitions), serializer)
  )

  override def getPreferredLocations(s: Partition): Seq[String] = {
    val fp = firstParent[A]
    // println(s"pref loc: ${fp.preferredLocations(fp.partitions(s.index))}")
    fp.preferredLocations(fp.partitions(s.index))
  }

  override def compute(split: Partition, context: TaskContext): Iterator[V] = {
    val currSplit = split.asInstanceOf[MapJoinPartitionsPartitionV2]
    val rdd2Dep = dependencies(1).asInstanceOf[ShuffleDependency[Int, Any, Any]]
    // One shuffle reader per selected rdd2 partition; drop the partition-id key on read.
    val rdd2PartIter = currSplit.s2Arr.map(s2 => (s2.index,
      SparkEnv.get.shuffleManager
        .getReader[Int, B](rdd2Dep.shuffleHandle, s2.index, s2.index + 1, context)
        .read().map(x => x._2)
    ))
    val rdd1Iter = rdd1.iterator(currSplit.s1, context)
    f(currSplit.s1.index, rdd1Iter, rdd2PartIter)
  }

  override def clearDependencies(): Unit = {
    // Let the base RDD drop its dependency list before releasing our own references.
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
    rdd2WithPid = null
    idxF = null
    f = null
  }
}

/**
 * Partitioner whose keys already are partition indices: `getPartition` returns the key itself.
 *
 * @param numParts number of partitions; must be positive
 */
private[spark] class IdentityPartitioner(val numParts: Int) extends Partitioner {
  require(numPartitions > 0)
  // Keys are expected to be Int partition ids; any other key type fails the cast.
  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
  override def numPartitions: Int = numParts
}
Example 39
Source File: CommitFailureTestSource.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

/**
 * Test data source whose writers always fail in `close()` and, when
 * `SimpleTextRelation.failWriter` is set, also fail in `write()`.
 */
class CommitFailureTestSource extends SimpleTextSource {
  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory =
    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new SimpleTextOutputWriter(path, dataSchema, context) {
          var failed = false
          // Record that the task-failure callback fired so tests can assert on it.
          TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
            failed = true
            SimpleTextRelation.callbackCalled = true
          }

          override def write(row: InternalRow): Unit = {
            if (SimpleTextRelation.failWriter) {
              sys.error("Intentional task writer failure for testing purpose.")
            }
            super.write(row)
          }

          override def close(): Unit = {
            super.close()
            sys.error("Intentional task commitment failure for testing purpose.")
          }
        }
      }

      override def getFileExtension(context: TaskAttemptContext): String = ""
    }

  override def shortName(): String = "commit-failure-test"
}
Example 40
Source File: ReadOnlySQLConf.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.internal

import java.util.{Map => JMap}

import org.apache.spark.TaskContext
import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader}

/**
 * A `SQLConf` backed by the local properties of a task context. Every mutating or copying
 * operation throws `UnsupportedOperationException`.
 *
 * @param context task context whose local properties supply the settings
 */
class ReadOnlySQLConf(context: TaskContext) extends SQLConf {

  @transient override val settings: JMap[String, String] = {
    context.getLocalProperties.asInstanceOf[JMap[String, String]]
  }

  @transient override protected val reader: ConfigReader = {
    new ConfigReader(new TaskContextConfigProvider(context))
  }

  override protected def setConfWithCheck(key: String, value: String): Unit = {
    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
  }

  override def unsetConf(key: String): Unit = {
    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
  }

  override def unsetConf(entry: ConfigEntry[_]): Unit = {
    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
  }

  override def clear(): Unit = {
    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
  }

  override def clone(): SQLConf = {
    throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
  }

  override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = {
    throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
  }
}

/**
 * `ConfigProvider` that looks keys up in the local properties of a task context.
 *
 * @param context task context to read local properties from
 */
class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider {
  // `Option(...)` maps a null property value to None.
  override def get(key: String): Option[String] = Option(context.getLocalProperty(key))
}
Example 41
Source File: ObjectAggregationMap.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.aggregate

import java.{util => ju}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.internal.config
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, TypedImperativeAggregate}
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

  def dumpToExternalSorter(
      groupingAttributes: Seq[Attribute],
      aggregateFunctions: Seq[AggregateFunction]): UnsafeKVExternalSorter = {
    val aggBufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
    val sorter = new UnsafeKVExternalSorter(

    val mapIterator = iterator
    val unsafeAggBufferProjection =

    while (mapIterator.hasNext) {
      val entry =
      aggregateFunctions.foreach {
        case agg: TypedImperativeAggregate[_] =>
        case _ =>



  def clear(): Unit = {

// Stores the grouping key and aggregation buffer
class AggregationBufferEntry(var groupingKey: UnsafeRow, var aggregationBuffer: InternalRow) 
Example 42
Source File: ShuffledHashJoinExec.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
 * Hash join that requires both children to be hash-clustered on their join keys, builds a
 * hashed relation from the build side in each task, and streams the other side through it.
 */
case class ShuffledHashJoinExec(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    buildSide: BuildSide,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan)
  extends BinaryExecNode with HashJoin {

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))

  override def requiredChildDistribution: Seq[Distribution] =
    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil

  /**
   * Builds the hashed relation for one task from the build-side rows, recording the build
   * time (nanoseconds converted to milliseconds) and the relation's estimated size.
   */
  private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
    val buildDataSize = longMetric("buildDataSize")
    val buildTime = longMetric("buildTime")
    val start = System.nanoTime()
    val context = TaskContext.get()
    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
    buildTime += (System.nanoTime() - start) / 1000000
    buildDataSize += relation.estimatedSize
    // This relation is usually used until the end of task.
    context.addTaskCompletionListener[Unit](_ => relation.close())
    relation
  }

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
      val hashed = buildHashedRelation(buildIter)
      join(streamIter, hashed, numOutputRows)
    }
  }
}
Example 43
Source File: CodecStreams.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.io.{InputStream, OutputStream, OutputStreamWriter}
import java.nio.charset.{Charset, StandardCharsets}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.util.ReflectionUtils

import org.apache.spark.TaskContext

object CodecStreams {
  private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = {
    val compressionCodecs = new CompressionCodecFactory(config)

  def createInputStream(config: Configuration, file: Path): InputStream = {
    val fs = file.getFileSystem(config)
    val inputStream: InputStream =

    getDecompressionCodec(config, file)
      .map(codec => codec.createInputStream(inputStream))

  def getCompressionExtension(context: JobContext): String = {
Example 44
Source File: DataSourceRDD.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.v2

import scala.reflect.ClassTag

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.v2.reader.InputPartition

/**
 * RDD partition that pairs a partition index with the `InputPartition` it wraps.
 *
 * @param index          position of this partition
 * @param inputPartition the input partition carried by this RDD partition
 */
class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T])
  extends Partition with Serializable

/**
 * RDD with one partition per `InputPartition`; each task reads its partition through the
 * reader the input partition creates.
 *
 * @param sc              the Spark context
 * @param inputPartitions the input partitions, one per RDD partition (transient, driver-side only)
 */
class DataSourceRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val inputPartitions: Seq[InputPartition[T]])
  extends RDD[T](sc, Nil) {

  override protected def getPartitions: Array[Partition] = {
    inputPartitions.zipWithIndex.map {
      case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition)
    }.toArray
  }

  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition
      .createPartitionReader()
    // Close the reader when the task finishes, whether or not the iterator was drained.
    context.addTaskCompletionListener[Unit](_ => reader.close())
    val iter = new Iterator[T] {
      // True when `reader.next()` has advanced to a value that `next()` has not yet returned.
      private[this] var valuePrepared = false

      override def hasNext: Boolean = {
        if (!valuePrepared) {
          valuePrepared = reader.next()
        }
        valuePrepared
      }

      override def next(): T = {
        if (!hasNext) {
          throw new java.util.NoSuchElementException("End of stream")
        }
        valuePrepared = false
        reader.get()
      }
    }
    new InterruptibleIterator(context, iter)
  }

  override def getPreferredLocations(split: Partition): Seq[String] = {
    split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations()
  }
}
Example 45
Source File: BasicWriteStatsTracker.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources


import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.SerializableConfiguration

class BasicWriteJobStatsTracker(
    serializableHadoopConf: SerializableConfiguration,
    @transient val metrics: Map[String, SQLMetric])
  extends WriteJobStatsTracker {

  override def newTaskInstance(): WriteTaskStatsTracker = {
    new BasicWriteTaskStatsTracker(serializableHadoopConf.value)

  override def processStats(stats: Seq[WriteTaskStats]): Unit = {
    val sparkContext = SparkContext.getActive.get
    var numPartitions: Long = 0L
    var numFiles: Long = 0L
    var totalNumBytes: Long = 0L
    var totalNumOutput: Long = 0L

    val basicStats =[BasicWriteTaskStats])

    basicStats.foreach { summary =>
      numPartitions += summary.numPartitions
      numFiles += summary.numFiles
      totalNumBytes += summary.numBytes
      totalNumOutput += summary.numRows


    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList)

/** Metric keys and the factory for the SQL metrics used by `BasicWriteJobStatsTracker`. */
object BasicWriteJobStatsTracker {
  private val NUM_FILES_KEY = "numFiles"
  private val NUM_OUTPUT_BYTES_KEY = "numOutputBytes"
  private val NUM_OUTPUT_ROWS_KEY = "numOutputRows"
  private val NUM_PARTS_KEY = "numParts"

  /**
   * Creates a fresh set of write metrics registered on the active Spark context.
   * Throws if there is no active context, since `SparkContext.getActive.get` is used.
   */
  def metrics: Map[String, SQLMetric] = {
    val sparkContext = SparkContext.getActive.get
    Map(
      NUM_FILES_KEY -> SQLMetrics.createMetric(sparkContext, "number of written files"),
      NUM_OUTPUT_BYTES_KEY -> SQLMetrics.createMetric(sparkContext, "bytes of written output"),
      NUM_OUTPUT_ROWS_KEY -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
      NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part")
    )
  }
}
Example 46
Source File: EvalPythonExec.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.python


import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.ChainedPythonFunctions
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils

abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends SparkPlan {

  def children: Seq[SparkPlan] = child :: Nil

  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))

  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
    udf.children match {
      case Seq(u: PythonUDF) =>
        val (chained, children) = collectFunctions(u)
        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
      case children =>
        // There should not be any other UDFs, or the children can't be evaluated directly.
        (ChainedPythonFunctions(Seq(udf.func)), udf.children)

  protected def evaluate(
      funcs: Seq[ChainedPythonFunctions],
      argOffsets: Array[Array[Int]],
      iter: Iterator[InternalRow],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow]

  protected override def doExecute(): RDD[InternalRow] = {
    val inputRDD = child.execute().map(_.copy())

    inputRDD.mapPartitions { iter =>
      val context = TaskContext.get()

      // The queue used to buffer input rows so we can drain it to
      // combine input with output from Python.
      val queue = HybridRowQueue(context.taskMemoryManager(),
        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
      context.addTaskCompletionListener[Unit] { ctx =>

      val (pyFuncs, inputs) =

      // flatten all the arguments
      val allInputs = new ArrayBuffer[Expression]
      val dataTypes = new ArrayBuffer[DataType]
      val argOffsets = { input => { e =>
          if (allInputs.exists(_.semanticEquals(e))) {
          } else {
            allInputs += e
            dataTypes += e.dataType
            allInputs.length - 1
      val projection = newMutableProjection(allInputs, child.output)
      val schema = StructType( { case (dt, i) =>
        StructField(s"_$i", dt)

      // Add rows to queue to join later with the result.
      val projectedRowIter = { inputRow =>

      val outputRowIterator = evaluate(
        pyFuncs, argOffsets, projectedRowIter, schema, context)

      val joined = new JoinedRow
      val resultProj = UnsafeProjection.create(output, output) { outputRow =>
        resultProj(joined(queue.remove(), outputRow))
Example 47
Source File: ArrowEvalPythonExec.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types.StructType

case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends EvalPythonExec(udfs, output, child) {

  private val batchSize = conf.arrowMaxRecordsPerBatch
  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)

  protected override def evaluate(
      funcs: Seq[ChainedPythonFunctions],
      argOffsets: Array[Array[Int]],
      iter: Iterator[InternalRow],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow] = {

    val outputTypes = output.drop(child.output.length).map(_.dataType)

    // DO NOT use iter.grouped(). See BatchIterator.
    val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)

    val columnarBatchIter = new ArrowPythonRunner(
      pythonRunnerConf).compute(batchIter, context.partitionId(), context)

    new Iterator[InternalRow] {

      private var currentIter = if (columnarBatchIter.hasNext) {
        val batch =
        val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
        assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
          s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
      } else {

      override def hasNext: Boolean = currentIter.hasNext || {
        if (columnarBatchIter.hasNext) {
          currentIter =
        } else {

      override def next(): InternalRow =
Example 48
Source File: BatchEvalPythonExec.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}

case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends EvalPythonExec(udfs, output, child) {

  protected override def evaluate(
      funcs: Seq[ChainedPythonFunctions],
      argOffsets: Array[Array[Int]],
      iter: Iterator[InternalRow],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow] = {
    EvaluatePython.registerPicklers()  // register pickler for Row

    val dataTypes =
    val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)

    // enable memo iff we serialize the row with schema (schema and class should be memorized)
    val pickle = new Pickler(needConversion)
    // Input iterator to Python: input rows are grouped so we send them in batches to Python.
    // For each row, add it to the queue.
    val inputIterator = { row =>
      if (needConversion) {
        EvaluatePython.toJava(row, schema)
      } else {
        // fast path for these types that does not need conversion in Python
        val fields = new Array[Any](row.numFields)
        var i = 0
        while (i < row.numFields) {
          val dt = dataTypes(i)
          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
          i += 1
    }.grouped(100).map(x => pickle.dumps(x.toArray))

    // Output iterator for results from Python.
    val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
      .compute(inputIterator, context.partitionId(), context)

    val unpickle = new Unpickler
    val mutableRow = new GenericInternalRow(1)
    val resultType = if (udfs.length == 1) {
    } else {
      StructType( => StructField("", u.dataType, u.nullable)))

    val fromJava = EvaluatePython.makeFromJava(resultType)

    outputIterator.flatMap { pickedResult =>
      val unpickledBatch = unpickle.loads(pickedResult)
    }.map { result =>
      if (udfs.length == 1) {
        // fast path for single UDF
        mutableRow(0) = fromJava(result)
      } else {
Example 49
Source File: StateStoreRDD.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.state

import java.util.UUID

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.EpochTracker
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

  override def getPreferredLocations(partition: Partition): Seq[String] = {
    val stateStoreProviderId = StateStoreProviderId(
      StateStoreId(checkpointLocation, operatorId, partition.index),

  override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
    var store: StateStore = null
    val storeProviderId = StateStoreProviderId(
      StateStoreId(checkpointLocation, operatorId, partition.index),

    // If we're in continuous processing mode, we should get the store version for the current
    // epoch rather than the one at planning time.
    val isContinuous = Option(ctxt.getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING))
    val currentVersion = if (isContinuous) {
      val epoch = EpochTracker.getCurrentEpoch
      assert(epoch.isDefined, "Current epoch must be defined for continuous processing streams.")
    } else {

    store = StateStore.get(
      storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion,
      storeConf, hadoopConfBroadcast.value.value)
    val inputIter = dataRDD.iterator(partition, ctxt)
    storeUpdateFunction(store, inputIter)
Example 50
Source File: package.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import scala.reflect.ClassTag

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType

package object state {

  implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {

    private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
        stateInfo: StatefulOperatorStateInfo,
        keySchema: StructType,
        valueSchema: StructType,
        indexOrdinal: Option[Int],
        sessionState: SessionState,
        storeCoordinator: Option[StateStoreCoordinatorRef])(
        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {

      val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
      val wrappedF = (store: StateStore, iter: Iterator[T]) => {
        // Abort the state store in case of error
        TaskContext.get().addTaskCompletionListener[Unit](_ => {
          if (!store.hasCommitted) store.abort()
        cleanedF(store, iter)

      new StateStoreRDD(
Example 51
Source File: ContinuousWriteRDD.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.continuous

import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}
import org.apache.spark.util.Utils

class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow])
    extends RDD[Unit](prev) {

  override val partitioner = prev.partitioner

  override def getPartitions: Array[Partition] = prev.partitions

  override def compute(split: Partition, context: TaskContext): Iterator[Unit] = {
    val epochCoordinator = EpochCoordinatorRef.get(
    while (!context.isInterrupted() && !context.isCompleted()) {
      var dataWriter: DataWriter[InternalRow] = null
      // write the data and commit this writer.
      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
        try {
          val dataIterator = prev.compute(split, context)
          dataWriter = writeTask.createDataWriter(
          while (dataIterator.hasNext) {
          logInfo(s"Writer for partition ${context.partitionId()} " +
            s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.")
          val msg = dataWriter.commit()
          logInfo(s"Writer for partition ${context.partitionId()} " +
            s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.")
        } catch {
          case _: InterruptedException =>
          // Continuous shutdown always involves an interrupt. Just finish the task.
      })(catchBlock = {
        // If there is an error, abort this writer. We enter this callback in the middle of
        // rethrowing an exception, so compute() will stop executing at this point.
        logError(s"Writer for partition ${context.partitionId()} is aborting.")
        if (dataWriter != null) dataWriter.abort()
        logError(s"Writer for partition ${context.partitionId()} aborted.")


  override def clearDependencies() {
    prev = null
Example 52
Source File: ContinuousShuffleReadRDD.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.continuous.shuffle

import java.util.UUID

import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcAddress
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.NextIterator

case class ContinuousShuffleReadPartition(
      index: Int,
      endpointName: String,
      queueSize: Int,
      numShuffleWriters: Int,
      epochIntervalMs: Long)
    extends Partition {
  // Initialized only on the executor, and only once even as we call compute() multiple times.
  lazy val (reader: ContinuousShuffleReader, endpoint) = {
    val env = SparkEnv.get.rpcEnv
    val receiver = new RPCContinuousShuffleReader(
      queueSize, numShuffleWriters, epochIntervalMs, env)
    val endpoint = env.setupEndpoint(endpointName, receiver)

    TaskContext.get().addTaskCompletionListener[Unit] { ctx =>
    (receiver, endpoint)

class ContinuousShuffleReadRDD(
    sc: SparkContext,
    numPartitions: Int,
    queueSize: Int = 1024,
    numShuffleWriters: Int = 1,
    epochIntervalMs: Long = 1000,
    val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}"))
  extends RDD[UnsafeRow](sc, Nil) {

  override protected def getPartitions: Array[Partition] = {
    (0 until numPartitions).map { partIndex =>
        partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs)

  override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
Example 53
Source File: ReferenceSort.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

/**
 * Sort operator built on `ExternalSorter`.
 *
 * @param sortOrder ordering to sort by
 * @param global    when true, requires the child to be range-distributed by `sortOrder`
 * @param child     the plan whose output is sorted
 */
case class ReferenceSort(
    sortOrder: Seq[SortOrder],
    global: Boolean,
    child: SparkPlan)
  extends UnaryExecNode {

  override def requiredChildDistribution: Seq[Distribution] =
    if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
    child.execute().mapPartitions( { iterator =>
      val ordering = newOrdering(sortOrder, child.output)
      val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
        TaskContext.get(), ordering = Some(ordering))
      // Rows are copied before insertion; the null value is a placeholder for the unused value.
      sorter.insertAll(iterator.map(r => (r.copy(), null)))
      val baseIterator = sorter.iterator.map(_._1)
      val context = TaskContext.get()
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
      // Stop the sorter once the sorted iterator is exhausted.
      CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
    }, preservesPartitioning = true)
  }

  override def output: Seq[Attribute] = child.output

  override def outputOrdering: Seq[SortOrder] = sortOrder

  override def outputPartitioning: Partitioning = child.outputPartitioning
}
Example 54
Source File: KinesisRDDWriter.scala    From aws-kinesis-scala   with Apache License 2.0 5 votes vote down vote up

import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
import com.amazonaws.regions.Regions
import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.TaskContext
import org.json4s.jackson.JsonMethods
import org.json4s.{DefaultFormats, Extraction, Formats}
import org.slf4j.LoggerFactory

/**
 * Serializable writer that pushes the elements of an iterator to the Kinesis stream
 * `streamName` in batches.
 *
 * Members visible in this excerpt:
 *  - `write`     folds over the input, buffering entries and flushing them through the
 *                local `put` helper; failures are collected and handed to `dump`.
 *  - `dump`      takes the failed (entry, message) pairs and builds a multi-line report
 *                from them (its surrounding call is not visible in this excerpt).
 *  - `serialize` turns one element into the byte payload (body not visible here).
 *
 * NOTE(review): this excerpt is truncated -- several closing braces, an `else` before the
 * buffering branch, and parts of some expressions are missing.
 *
 * @param streamName  name passed to every `PutRecordsRequest`
 * @param region      region used to obtain the Kinesis client
 * @param credentials credentials used to obtain the Kinesis client
 * @param chunk       number of buffered records at which a flush is triggered
 * @param endpoint    optional endpoint; presumably selects `endpointClient` over `client`
 *                    -- TODO confirm, the selecting expression is cut off in `put`
 */
class KinesisRDDWriter[A <: AnyRef](streamName: String, region: Regions,
                                    credentials: SparkAWSCredentials,
                                    chunk: Int, endpoint: Option[String]) extends Serializable {
  private val logger = LoggerFactory.getLogger(getClass)

  // Writes `data` to the stream. `task` is not referenced in the visible part of the body.
  def write(task: TaskContext, data: Iterator[A]): Unit = {
    // send data, including retry
    // Pairs every entry whose result is a `Left(e)` with "errorCode: errorMessage";
    // entries are matched to results by position (`zipWithIndex` / `a(i)`).
    def put(a: Seq[PutRecordsEntry]) = => KinesisRDDWriter.endpointClient(credentials)(e)(region))
      .putRecordsWithRetry(PutRecordsRequest(streamName, a))
      .zipWithIndex.collect { case (Left(e), i) => a(i) -> s"${e.errorCode}: ${e.errorMessage}" }

    // Fold state: (records buffered for the next request, failures collected so far).
    val errors = data.foldLeft(
      (Nil: Seq[PutRecordsEntry], Nil: Seq[(PutRecordsEntry, String)])
    ){ (z, x) =>
      val (records, failed) = z
      val payload = serialize(x)
      // The SHA-256 hex digest of the payload is the first argument of the entry
      // (presumably the partition key -- TODO confirm against PutRecordsEntry).
      val entry   = PutRecordsEntry(DigestUtils.sha256Hex(payload), payload)

      // record exceeds max size
      if (entry.recordSize > recordMaxDataSize)
        records -> ((entry -> "per-record size limit") +: failed)

      // execute
      // Flushes the current buffer and starts a new one holding only `entry`.
      // NOTE(review): the left operand of `+ entry.recordSize` is missing in this excerpt.
      else if (records.size >= chunk || ( + entry.recordSize) >= recordsMaxDataSize)
        (entry +: Nil) -> (put(records) ++ failed)

      // buffering
        (entry +: records) -> failed
    } match {
      // Nothing left in the buffer: only the collected failures remain.
      case (Nil, e)  => e
      // Flush whatever is still buffered after the last element.
      case (rest, e) => put(rest) ++ e

    // failed records
    if (errors.nonEmpty) dump(errors)

  protected def dump(errors: Seq[(PutRecordsEntry, String)]): Unit =
      s"""Could not put record, count: ${errors.size}, following details:
         |${errors map { case (entry, message) => message + "\n" + new String(, "UTF-8") } mkString "\n"}

  protected def serialize(a: A)(implicit formats: Formats = DefaultFormats): Array[Byte] =

/**
 * Companion holding a cache of `AmazonKinesis` clients shared by all writers.
 *
 * NOTE(review): both factories below store into the same `cache`, keyed by region only.
 * As written, once a client exists for a region, a later call with a different endpoint
 * (or without one) gets that same cached client back -- confirm this is intended.
 * NOTE(review): the closing braces of both function literals and of the object are
 * missing from this excerpt.
 */
object KinesisRDDWriter {
  // Concurrent map: region -> client.
  private val cache = collection.concurrent.TrieMap.empty[Regions, AmazonKinesis]

  // Curried factory: credentials => region => client built from the credentials provider.
  private val client: SparkAWSCredentials => Regions => AmazonKinesis = {
    credentials => implicit region =>
      cache.getOrElseUpdate(region, AmazonKinesis(credentials.provider))

  // Curried factory: credentials => endpoint => region => client built with an explicit
  // EndpointConfiguration(endpoint, region.getName).
  private val endpointClient: SparkAWSCredentials => String => Regions => AmazonKinesis = {
    credentials => endpoint => implicit region =>
      cache.getOrElseUpdate(region, AmazonKinesis(credentials.provider, new EndpointConfiguration(endpoint, region.getName)))

Example 55
Source File: EventHubsRDD.scala    From azure-event-hubs-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.eventhubs.rdd

import org.apache.spark.eventhubs.EventHubsConf
import org.apache.spark.eventhubs.client.CachedEventHubsReceiver
import org.apache.spark.eventhubs.utils.SimulatedCachedReceiver
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.{ Partition, SparkContext, TaskContext }

/**
 * RDD of `EventData` defined by a set of offset ranges, with no parent RDDs (`Nil`).
 *
 * NOTE(review): this excerpt is heavily truncated -- the bodies of `getPartitions`,
 * `count` and `getPreferredLocations`, the end of `take`, the logging calls in `compute`
 * and most closing braces are missing. Comments describe only what is visible.
 *
 * @param sc           Spark context passed to the RDD constructor
 * @param ehConf       Event Hubs configuration; `useSimulatedClient` is read in `compute`
 * @param offsetRanges ranges to read (presumably one per partition -- TODO confirm)
 */
private[spark] class EventHubsRDD(sc: SparkContext,
                                  val ehConf: EventHubsConf,
                                  val offsetRanges: Array[OffsetRange])
    extends RDD[EventData](sc, Nil)
    with Logging
    with HasOffsetRanges {

  // Builds one EventHubsRDDPartition per element `o` (source collection not visible).
  override def getPartitions: Array[Partition] =
        o =>
          new EventHubsRDDPartition(

  override def count: Long =

  // Defined via `count`.
  override def isEmpty(): Boolean = count == 0L

  // Takes up to `num` events, visiting only partitions whose `count` is positive.
  override def take(num: Int): Array[EventData] = {
    val nonEmptyPartitions =[EventHubsRDDPartition]).filter(_.count > 0)

    // Nothing requested, or nothing available.
    if (num < 1 || nonEmptyPartitions.isEmpty) {
      return Array()

    // partition index -> number of events to take from it; partitions are filled in
    // order until the running total reaches `num`.
    val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) =>
      val remain = num - result.values.sum
      if (remain > 0) {
        val taken = Math.min(remain, part.count)
        result + (part.index -> taken.toInt)

      } else {

        // Per-task function: take this partition's share as computed in `parts`.
        (tc: TaskContext, it: Iterator[EventData]) => it.take(parts(tc.partitionId)).toArray,

  override def getPreferredLocations(split: Partition): Seq[String] = {
    val part = split.asInstanceOf[EventHubsRDDPartition]

  // Message used by the assertion in `compute` (tail of the string is cut off here).
  private def errBeginAfterEnd(part: EventHubsRDDPartition): String =
    s"The beginning sequence number ${part.fromSeqNo} is larger than the ending " +
      s"sequence number ${part.untilSeqNo} for EventHubs ${} on partition " +

  override def compute(partition: Partition, context: TaskContext): Iterator[EventData] = {
    val part = partition.asInstanceOf[EventHubsRDDPartition]
    // A range must not start after it ends.
    assert(part.fromSeqNo <= part.untilSeqNo, errBeginAfterEnd(part))

    // Equal bounds: per the message below, an empty partition is returned.
    if (part.fromSeqNo == part.untilSeqNo) {
        s"(TID ${context.taskAttemptId()}) Beginning sequence number ${part.fromSeqNo} is equal to the ending sequence " +
          s"number ${part.untilSeqNo}. Returning empty partition for EH: ${} " +
          s"on partition: ${part.partitionId}")
    } else {
        s"(TID ${context.taskAttemptId()}) Computing EventHubs ${}, partition ${part.partitionId} " +
          s"sequence numbers ${part.fromSeqNo} => ${part.untilSeqNo}")
      // Receiver choice depends on `ehConf.useSimulatedClient`; the last visible argument
      // is the range length, untilSeqNo - fromSeqNo, converted to Int.
      val cachedReceiver = if (ehConf.useSimulatedClient) {
      } else {
                             (part.untilSeqNo - part.fromSeqNo).toInt)
Example 56
Source File: SplashShuffleWriter.scala    From splash   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.shuffle

import org.apache.spark.TaskContext
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.MapStatus

  /**
   * Stops the writer. On the first call (`stopping` not yet set) with `success == true`
   * the visible code yields `MapStatus(resolver.blockManagerId, partitionLengths)` wrapped
   * in `Option`; the values for the other branches are not visible in this excerpt.
   *
   * NOTE(review): excerpt is truncated -- branch bodies and closing braces are missing,
   * and in the `finally` block no statement is visible between taking `startTime` and
   * recording the elapsed time (presumably the sorter is stopped there -- TODO confirm).
   *
   * @param success whether the task finished successfully
   */
  override def stop(success: Boolean): Option[MapStatus] = {
    try {
      if (stopping) {
      } else {
        // Mark as stopping so a repeated call takes the other branch.
        stopping = true
        if (success) {
          Option(MapStatus(resolver.blockManagerId, partitionLengths))
        } else {
    } finally {
      if (sorter != null) {
        val startTime = System.nanoTime
        // Elapsed nanoseconds since `startTime` are added to the write-time metric.
        writeMetrics.incWriteTime(System.nanoTime - startTime)
        sorter = null
Example 57
Source File: SplashShuffleReader.scala    From splash   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.shuffle

import org.apache.spark.internal.Logging
import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}

/**
 * Shuffle reader that looks up map output sizes from the `mapOutputTracker`, fetches the
 * non-empty blocks through `SplashShuffleFetcherIterator`, and optionally aggregates and
 * sorts the records as dictated by the shuffle dependency.
 *
 * NOTE(review): this excerpt is truncated -- the end of `read`, the body of the
 * `Seq`-based `readShuffleBlocks` overload, both `case None` results, the use of the
 * sorter, and the closing braces are missing.
 *
 * @param resolver         block resolver handed to the fetcher iterator
 * @param handle           shuffle handle; supplies `shuffleId` and the dependency
 * @param startPartition   first argument of the partition range passed to the tracker
 * @param endPartition     second argument of the partition range passed to the tracker
 * @param context          task context used for metrics, aggregation and interruption
 * @param mapOutputTracker tracker to query; defaults to the one in `SparkEnv`
 */
private[spark] class SplashShuffleReader[K, V](
    resolver: SplashShuffleBlockResolver,
    handle: BaseShuffleHandle[K, _, V],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext,
    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
    extends ShuffleReader[K, V] with Logging {

  private val dep = handle.dependency

  // Local aliases that keep the iterator signatures below short.
  private type Pair = (Any, Any)
  private type KCPair = (K, V)
  private type KCIterator = Iterator[KCPair]

  // Entry point: asks the tracker for the map output sizes of this reader's range.
  override def read(): KCIterator = {
    val shuffleBlocks = mapOutputTracker.getMapSizesByExecutorId(
      handle.shuffleId, startPartition, endPartition)

  def readShuffleBlocks(shuffleBlocks: Seq[(BlockId, Long)]): KCIterator =

  // Reads the given (block id, size) pairs and returns the resulting record iterator.
  def readShuffleBlocks(shuffleBlocks: Iterator[(BlockId, Long)]): KCIterator = {
    val taskMetrics = context.taskMetrics()
    val serializer = SplashSerializer(dep)

    // Only blocks with a positive size are fetched.
    val nonEmptyBlocks = shuffleBlocks.filter(_._2 > 0).map(_._1)
    val fetcherIterator = SplashShuffleFetcherIterator(resolver, nonEmptyBlocks)

    // Applies the dependency's aggregator, if one is defined.
    def getAggregatedIterator(iterator: Iterator[Pair]): KCIterator = {
      dep.aggregator match {
        case Some(agg) =>
          val aggregator = new SplashAggregator(agg)
          if (dep.mapSideCombine) {
            // We are reading values that are already combined
            val combinedKeyValuesIterator = iterator.asInstanceOf[Iterator[(K, V)]]
            aggregator.combineCombinersByKey(combinedKeyValuesIterator, context)
          } else {
            // We don't know the value type, but also don't care -- the dependency *should*
            // have made sure its compatible w/ this aggregator, which will convert the value
            // type to the combined type C
            val keyValuesIterator = iterator.asInstanceOf[Iterator[(K, Nothing)]]
            aggregator.combineValuesByKey(keyValuesIterator, context)
        case None =>
          // Map-side combine is only meaningful together with an aggregator.
          require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")

    def getSortedIterator(iterator: KCIterator): KCIterator = {
      // Sort the output if there is a sort ordering defined.
      dep.keyOrdering match {
        case Some(keyOrd: Ordering[K]) =>
          // Create an ExternalSorter to sort the data.
          val sorter = new SplashSorter[K, V, V](
            context, ordering = Some(keyOrd), serializer = serializer)
        case None =>

    // Each fetched block is expanded into a record iterator via `asMetricIterator`.
    val metricIter = fetcherIterator.flatMap(
      _.asMetricIterator(serializer, taskMetrics))

    // An interruptible iterator must be used here in order to support task cancellation
        new InterruptibleIterator[Pair](context, metricIter)))
Example 58
Source File: SplashAggregator.scala    From splash   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.shuffle

import org.apache.spark.{Aggregator, TaskContext}

/**
 * Aggregator that reuses the three functions of an existing [[Aggregator]] but builds
 * its combiners in a `SplashAppendOnlyMap`.
 *
 * NOTE(review): this excerpt is truncated -- the statements that feed `iter` into the
 * map and return its iterator, the body of `updateMetrics`, and the closing braces are
 * missing.
 *
 * @param agg aggregator whose createCombiner / mergeValue / mergeCombiners are reused
 */
class SplashAggregator[K, V, C](
    agg: Aggregator[K, V, C])
    extends Aggregator[K, V, C](
      agg.createCombiner, agg.mergeValue, agg.mergeCombiners) {

  // Combines raw values: the map is created with createCombiner / mergeValue /
  // mergeCombiners.
  override def combineValuesByKey(
      iter: Iterator[_ <: Product2[K, V]],
      context: TaskContext): Iterator[(K, C)] = {
    val combiners = new SplashAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
    updateMetrics(context, combiners)

  // Combines already-combined values: `identity` is used as the create function and
  // mergeCombiners for both merge steps.
  override def combineCombinersByKey(
      iter: Iterator[_ <: Product2[K, C]],
      context: TaskContext): Iterator[(K, C)] = {
    val combiners = new SplashAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
    updateMetrics(context, combiners)

  // `Option(context)` makes the update a no-op when `context` is null.
  private def updateMetrics(context: TaskContext, map: SplashAppendOnlyMap[_, _, _]): Unit = {
    Option(context).foreach { c =>
Example 59
Source File: CloseableColumnBatchIterator.scala    From OAP   with Apache License 2.0 5 votes vote down vote up

import org.apache.spark.internal.Logging
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.TaskContext

/**
 * Iterator wrapper over `ColumnarBatch`es that remembers the current batch in `cb` so it
 * can be closed later, and registers a task-completion listener.
 *
 * NOTE(review): this excerpt is truncated -- the close call inside `closeCurrentBatch`,
 * the receiver and body of the completion listener, the bodies of `hasNext` / `next`,
 * and the closing braces are missing.
 *
 * @param itr underlying batch iterator
 */
class CloseableColumnBatchIterator(itr: Iterator[ColumnarBatch])
    extends Iterator[ColumnarBatch]
    with Logging {
  // Batch tracked by this iterator; null when there is none.
  var cb: ColumnarBatch = null

  // Resets `cb` to null when a batch is currently held.
  private def closeCurrentBatch(): Unit = {
    if (cb != null) {
      //logInfo(s"${itr} close ${cb}.")
      cb = null

    // Listener invoked with the TaskContext when the task completes.
    .addTaskCompletionListener[Unit]((tc: TaskContext) => {

  override def hasNext: Boolean = {

  // Assigns the newly obtained batch to `cb` (right-hand side not visible here).
  override def next(): ColumnarBatch = {
    cb =
Example 60
Source File: ColumnarShuffledHashJoinExec.scala    From OAP   with Apache License 2.0 5 votes vote down vote up

import java.util.concurrent.TimeUnit._


import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import scala.collection.mutable.ListBuffer
import org.apache.arrow.vector.ipc.message.ArrowFieldNode
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.Field
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.arrow.gandiva.expression._
import org.apache.arrow.gandiva.evaluator._

import io.netty.buffer.ArrowBuf

import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}

/**
 * Columnar variant of [[ShuffledHashJoinExec]]: zips the streamed and build sides
 * partition by partition and delegates the join to `ColumnarShuffledHashJoin`.
 *
 * NOTE(review): this excerpt is truncated -- most of the super-constructor argument
 * list, the body of the task-completion listener, and the closing braces are missing.
 */
class ColumnarShuffledHashJoinExec(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    buildSide: BuildSide,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan) extends ShuffledHashJoinExec(
    right) {

  // Metrics reported by this operator: one row counter and two timing metrics.
  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "joinTime" -> SQLMetrics.createTimingMetric(sparkContext, "join time"),
    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))

  override def supportsColumnar = true

  //TODO() Disable code generation
  //override def supportCodegen: Boolean = false

  // Produces the join result as columnar batches.
  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
    val numOutputRows = longMetric("numOutputRows")
    val joinTime = longMetric("joinTime")
    val buildTime = longMetric("buildTime")
    val resultSchema = this.schema
    // One join instance is created per zipped pair of partitions.
    streamedPlan.executeColumnar().zipPartitions(buildPlan.executeColumnar()) { (streamIter, buildIter) =>
      //val hashed = buildHashedRelation(buildIter)
      //join(streamIter, hashed, numOutputRows)
      val vjoin = ColumnarShuffledHashJoin.create(leftKeys, rightKeys, resultSchema, joinType, buildSide, condition, left, right, buildTime, joinTime, numOutputRows)
      val vjoinResult = vjoin.columnarInnerJoin(streamIter, buildIter)
      // A listener is registered on the running task's completion (body not visible).
      TaskContext.get().addTaskCompletionListener[Unit](_ => {
      // The result is wrapped in a CloseableColumnBatchIterator before being returned.
      new CloseableColumnBatchIterator(vjoinResult)
Example 61
Source File: ColumnarSortExec.scala    From OAP   with Apache License 2.0 5 votes vote down vote up


import java.util.concurrent.TimeUnit._

import org.apache.spark.{SparkEnv, TaskContext, SparkContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.sql.execution._
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

/**
 * Columnar variant of [[SortExec]] that sorts `ColumnarBatch` partitions through a
 * `ColumnarSorter`, with code generation disabled.
 *
 * NOTE(review): this excerpt is truncated -- the empty-input branch, the arguments of
 * `ColumnarSorter.create`, the receiver and body of the completion listener, and the
 * closing braces are missing.
 *
 * @param sortOrder          passed through to `SortExec`
 * @param global             passed through to `SortExec`
 * @param child              passed through to `SortExec`; executed in columnar mode
 * @param testSpillFrequency passed through to `SortExec`; defaults to 0
 */
class ColumnarSortExec(
    sortOrder: Seq[SortOrder],
    global: Boolean,
    child: SparkPlan,
    testSpillFrequency: Int = 0)
    extends SortExec(sortOrder, global, child, testSpillFrequency) {
  override def supportsColumnar = true

  // Disable code generation
  override def supportCodegen: Boolean = false

  // Three timing metrics plus counters for output rows and output batches.
  override lazy val metrics = Map(
    "totalSortTime" -> SQLMetrics
      .createTimingMetric(sparkContext, "time in sort + shuffle process"),
    "sortTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in sort process"),
    "shuffleTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in shuffle process"),
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"))

  // Sorts each partition of the child's columnar output.
  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
    val elapse = longMetric("totalSortTime")
    val sortTime = longMetric("sortTime")
    val shuffleTime = longMetric("shuffleTime")
    val numOutputRows = longMetric("numOutputRows")
    val numOutputBatches = longMetric("numOutputBatches")
    child.executeColumnar().mapPartitions { iter =>
      // Partitions without any batch take the first branch (result not visible here).
      val hasInput = iter.hasNext
      val res = if (!hasInput) {
      } else {
        val sorter = ColumnarSorter.create(
          .addTaskCompletionListener[Unit](_ => {
        // The sorter's iterator is wrapped in a CloseableColumnBatchIterator.
        new CloseableColumnBatchIterator(sorter.createColumnarIterator(iter))
Example 62
Source File: TaskContextImplAdapter.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.oap.adapter

import java.util.Properties

import org.apache.spark.{TaskContext, TaskContextImpl}
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem

/**
 * Adapter that constructs a `TaskContextImpl` and exposes it as a [[TaskContext]].
 *
 * NOTE(review): this excerpt is truncated -- only one constructor argument
 * (`stageAttemptNumber = 0`) is visible; the remaining arguments and the closing
 * braces are missing.
 */
object TaskContextImplAdapter {

  /**
   * Creates a task context from the given identifiers and services.
   *
   * The stage attempt number is not a parameter: the visible constructor call fixes it
   * to 0.
   */
  def createTaskContextImpl(
      stageId: Int,
      partitionId: Int,
      taskAttemptId: Long,
      attemptNumber: Int,
      taskMemoryManager: TaskMemoryManager,
      localProperties: Properties,
      metricsSystem: MetricsSystem): TaskContext = {
    new TaskContextImpl(
      stageAttemptNumber = 0,
Example 63
Source File: SequoiadbRDD.scala    From spark-sequoiadb   with Apache License 2.0 5 votes vote down vote up
package com.sequoiadb.spark.rdd

import org.apache.spark.SparkContext
import com.sequoiadb.spark.partitioner._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.Filter
import org.apache.spark.{Partition, TaskContext}
import org.bson.BSONObject
import org.slf4j.{Logger, LoggerFactory}
import scala.collection.mutable.ArrayBuffer

  /**
   * Factory for `SequoiadbRDD` that takes a [[SQLContext]] and forwards its
   * `sparkContext` together with all other arguments, unchanged, to the constructor.
   *
   * Defaults visible in the signature: no partitioner, empty `requiredColumns` and
   * `filters`, `SequoiadbConfig.QUERYRETURNBSON` as return type, and a `queryLimit`
   * of -1 (presumably "no limit" -- TODO confirm against SequoiadbRDD).
   *
   * NOTE(review): the closing brace of the method is missing from this excerpt.
   */
  def apply (
    sc: SQLContext,
    config: SequoiadbConfig,
    partitioner: Option[SequoiadbPartitioner] = None,
    requiredColumns: Array[String] = Array(),
    filters: Array[Filter] = Array(),
    queryReturnType: Int = SequoiadbConfig.QUERYRETURNBSON,
    queryLimit: Long = -1) = {
    new SequoiadbRDD ( sc.sparkContext, config, partitioner,
      requiredColumns, filters, queryReturnType, queryLimit)
Example 64
Source File: SlidingRDD.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.rdd

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD

/**
 * Partition of a `SlidingRDD`.
 *
 * @param idx    index of this partition (exposed as `index`)
 * @param prev   the parent partition this one reads from
 * @param tail   extra elements appended after the parent partition's elements
 * @param offset carried alongside the partition; its consumer is not visible in this
 *               excerpt
 *
 * NOTE(review): the closing brace of the class is missing from this excerpt.
 */
class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int)
  extends Partition with Serializable {
  override val index: Int = idx

/**
 * RDD of sliding windows (arrays of `windowSize` elements, advancing by `step`) over the
 * elements of `parent`.
 *
 * NOTE(review): this excerpt is truncated -- parts of `compute`, the body of
 * `getPreferredLocations`, the empty-parent result and the collection of `sizes`/`heads`
 * in `getPartitions`, the final return value, and the closing braces are missing.
 *
 * @param parent     RDD to slide over
 * @param windowSize number of elements per window; must be positive
 * @param step       distance between the starts of consecutive windows; must be positive
 */
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int)
  extends RDD[Array[T]](parent) {

  // windowSize == 1 together with step == 1 is rejected as well.
  require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1),
    "Window size and step must be greater than 0, " +
      s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.")

  // Windows are built over the parent partition's elements followed by the partition's
  // `tail`.
  override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
    val part = split.asInstanceOf[SlidingRDDPartition[T]]
    (firstParent[T].iterator(part.prev, context) ++ part.tail)
      .sliding(windowSize, step)

  override def getPreferredLocations(split: Partition): Seq[String] =

  override def getPartitions: Array[Partition] = {
    val parentPartitions = parent.partitions
    val n = parentPartitions.length
    if (n == 0) {
    } else if (n == 1) {
      // Single parent partition: no tail and no offset are needed.
      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0))
    } else {
      // A window can reach at most w1 elements into the following partitions.
      val w1 = windowSize - 1
      // Get partition sizes and first w1 elements.
      val (sizes, heads) = parent.mapPartitions { iter =>
        val w1Array = iter.take(w1).toArray
        // Size = elements already taken + elements still left in the iterator.
        Iterator.single((w1Array.length + iter.length, w1Array))
      val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]]
      var i = 0
      var cumSize = 0
      var partitionIndex = 0
      while (i < n) {
        // Distance from the start of partition i to the next multiple of `step`,
        // measured on the cumulative size of all earlier partitions.
        val mod = cumSize % step
        val offset = if (mod == 0) 0 else step - mod
        val size = sizes(i)
        if (offset < size) {
          val tail = mutable.ListBuffer.empty[T]
          // Keep appending to the current tail until it has w1 elements.
          var j = i + 1
          while (j < n && tail.length < w1) {
            tail ++= heads(j).take(w1 - tail.length)
            j += 1
          // Emit a partition only if at least one full window fits after `offset`.
          if (sizes(i) + tail.length >= offset + windowSize) {
            partitions +=
              new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset)
            partitionIndex += 1
        cumSize += size
        i += 1

  // TODO: Override methods such as aggregate, which only requires one Spark job.
Example 65
Source File: CommitFailureTestSource.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.TaskContext
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

/**
 * Test-only data source (short name "commit-failure-test") whose output writers fail on
 * purpose so that failure handling can be exercised.
 *
 * NOTE(review): this excerpt is truncated -- closing braces are missing, and any lines
 * that may have followed the `sys.error` calls or guarded `close` are not visible.
 */
class CommitFailureTestSource extends SimpleTextSource {
  // Returns a factory whose writers are SimpleTextOutputWriters with failing overrides.
  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory =
    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new SimpleTextOutputWriter(path, context) {
          var failed = false
          // On task failure: remember it locally and flag it on SimpleTextRelation so
          // that code outside the writer can observe that the callback ran.
          TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
            failed = true
            SimpleTextRelation.callbackCalled = true

          // Fails the write when SimpleTextRelation.failWriter is set.
          override def write(row: Row): Unit = {
            if (SimpleTextRelation.failWriter) {
              sys.error("Intentional task writer failure for testing purpose.")


          // Fails the commit step.
          override def close(): Unit = {
            sys.error("Intentional task commitment failure for testing purpose.")

      // Output files get no extension.
      override def getFileExtension(context: TaskAttemptContext): String = ""

  override def shortName(): String = "commit-failure-test"
Source File: MonotonicallyIncreasingID.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}

  // Per-partition record counter; reset to 0 in `initializeInternal` and incremented on
  // every `evalInternal` call.
  @transient private[this] var count: Long = _

  // Partition index shifted left by 33 bits; set in `initializeInternal`.
  @transient private[this] var partitionMask: Long = _

  /**
   * Resets the counter and derives the mask for the given partition: the partition index
   * is widened to Long and shifted left by 33 bits, leaving the low 33 bits for `count`.
   *
   * NOTE(review): the closing brace of the method is missing from this excerpt.
   */
  override protected def initializeInternal(partitionIndex: Int): Unit = {
    count = 0L
    partitionMask = partitionIndex.toLong << 33

  // This expression never evaluates to null.
  override def nullable: Boolean = false

  // Result type of the expression: a 64-bit long.
  override def dataType: DataType = LongType

  /**
   * Returns `partitionMask + count` for the current row and then advances `count`.
   * `input` is not read.
   *
   * NOTE(review): the closing brace of the method is missing from this excerpt.
   */
  override protected def evalInternal(input: InternalRow): Long = {
    // Capture the value before incrementing so the first id of a partition is the mask.
    val currentCount = count
    count += 1
    partitionMask + currentCount

  /**
   * Code-generated counterpart of `initializeInternal` / `evalInternal`: declares two
   * long mutable-state fields, initialises them per partition (count = 0, mask =
   * partitionIndex shifted left by 33), and emits Java that yields mask + count and then
   * increments the count. The generated value is marked as never null.
   *
   * NOTE(review): the closing brace of the method is missing from this excerpt.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Fresh names avoid clashes with other generated fields.
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")

  // Name under which this expression is rendered; also used to build `sql`.
  override def prettyName: String = "monotonically_increasing_id"

  // SQL rendering: the pretty name followed by an empty argument list.
  override def sql: String = prettyName + "()"
Source File: ShuffledHashJoinExec.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
 * Hash join that requires both children to be clustered on their join keys, builds a
 * `HashedRelation` from the build side of each partition pair, and streams the other
 * side through it.
 *
 * NOTE(review): this excerpt is truncated -- `buildHashedRelation` declares a
 * `HashedRelation` result but the expression returning `relation` is not visible, and
 * closing braces are missing.
 */
case class ShuffledHashJoinExec(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    buildSide: BuildSide,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan)
  extends BinaryExecNode with HashJoin {

  // One row counter, one size metric and one timing metric.
  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))

  // Left child clustered by leftKeys, right child by rightKeys.
  override def requiredChildDistribution: Seq[Distribution] =
    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

  // Builds the hash relation for one build-side partition and records its build time
  // and estimated size.
  private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
    val buildDataSize = longMetric("buildDataSize")
    val buildTime = longMetric("buildTime")
    val start = System.nanoTime()
    val context = TaskContext.get()
    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
    // Elapsed nanoseconds converted to milliseconds.
    buildTime += (System.nanoTime() - start) / 1000000
    buildDataSize += relation.estimatedSize
    // This relation is usually used until the end of task.
    context.addTaskCompletionListener(_ => relation.close())

  // Zips streamed and build partitions and joins each pair.
  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
      val hashed = buildHashedRelation(buildIter)
      join(streamIter, hashed, numOutputRows)
Source File: StateStoreRDD.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.state

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

/**
 * RDD that pairs every partition of `dataRDD` with a `StateStore` and applies
 * `storeUpdateFunction` to the store and the partition's input iterator.
 *
 * NOTE(review): this excerpt is truncated -- the rest of `getPreferredLocations` and the
 * closing braces are missing.
 *
 * @param dataRDD             input RDD; also supplies the partitions
 * @param storeUpdateFunction function run per partition with the store and the input
 * @param checkpointLocation  first component of each `StateStoreId`
 * @param operatorId          second component of each `StateStoreId`
 * @param storeVersion        version passed to `StateStore.get`
 * @param keySchema           key schema passed to `StateStore.get`
 * @param valueSchema         value schema passed to `StateStore.get`
 * @param sessionState        source of the SQL conf and the Hadoop configuration
 * @param storeCoordinator    optional coordinator reference (transient; its use is not
 *                            visible in this excerpt)
 */
class StateStoreRDD[T: ClassTag, U: ClassTag](
    dataRDD: RDD[T],
    storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
    checkpointLocation: String,
    operatorId: Long,
    storeVersion: Long,
    keySchema: StructType,
    valueSchema: StructType,
    sessionState: SessionState,
    @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
  extends RDD[U](dataRDD) {

  private val storeConf = new StateStoreConf(sessionState.conf)

  // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
  private val confBroadcast = dataRDD.context.broadcast(
    new SerializableConfiguration(sessionState.newHadoopConf()))

  // Partitions are exactly those of the input RDD.
  override protected def getPartitions: Array[Partition] = dataRDD.partitions

  override def getPreferredLocations(partition: Partition): Seq[String] = {
    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)

  // Obtains the store identified by (checkpointLocation, operatorId, partition index)
  // and hands it, with the partition's input, to the update function.
  override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
    var store: StateStore = null
    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
    store = StateStore.get(
      storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
    val inputIter = dataRDD.iterator(partition, ctxt)
    storeUpdateFunction(store, inputIter)
Source File: ReferenceSort.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

/**
 * Physical plan node that sorts each partition of `child` by `sortOrder` with an
 * [[ExternalSorter]].
 *
 * @param sortOrder ordering expressions; returned as `outputOrdering`
 * @param global    if true, `OrderedDistribution(sortOrder)` is required of the child;
 *                  if false, `UnspecifiedDistribution`
 * @param child     input plan; its output attributes and partitioning are kept
 */
case class ReferenceSort(
    sortOrder: Seq[SortOrder],
    global: Boolean,
    child: SparkPlan)
  extends UnaryExecNode {

  // Distribution requirement depends solely on the `global` flag.
  override def requiredChildDistribution: Seq[Distribution] =
    if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

  // Per-partition sort; the partitioner of the child is kept
  // (`preservesPartitioning = true`).
  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
    child.execute().mapPartitions( { iterator =>
      val ordering = newOrdering(sortOrder, child.output)
      val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
        TaskContext.get(), ordering = Some(ordering))
      // Each row is copied before being inserted as a (row, null) pair.
      // NOTE(review): the source expression before `=>` and the right-hand side of
      // `baseIterator` are missing from this excerpt; restore them from upstream.
      sorter.insertAll( => (r.copy(), null)))
      val baseIterator =
      val context = TaskContext.get()
      // `sorter.stop()` is supplied as the completion action of the returned iterator.
      CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
    }, preservesPartitioning = true)

  override def output: Seq[Attribute] = child.output

  override def outputOrdering: Seq[SortOrder] = sortOrder

  override def outputPartitioning: Partitioning = child.outputPartitioning
Source File: SparkHadoopMapRedUtil.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mapred


import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext}
import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.internal.Logging

/**
 * Helper for committing the output of a Hadoop task attempt from a Spark task,
 * optionally asking the driver's output commit coordinator for permission first.
 *
 * NOTE(review): this excerpt is truncated -- the actual commit call inside
 * `performCommit`, the calls in the "can commit" and "no coordination" branches, the
 * log statement preceding the `CommitDeniedException`, and the closing braces are
 * missing.
 */
object SparkHadoopMapRedUtil extends Logging {
  /**
   * Commits the task's output if the committer reports that a commit is needed.
   *
   * @param committer     output committer that is asked `needsTaskCommit`
   * @param mrTaskContext Hadoop task attempt context
   * @param jobId         passed to the commit coordinator and to CommitDeniedException
   * @param splitId       passed to the commit coordinator and to CommitDeniedException
   * @throws CommitDeniedException when the driver does not authorize the commit
   */
  def commitTask(
      committer: MapReduceOutputCommitter,
      mrTaskContext: MapReduceTaskAttemptContext,
      jobId: Int,
      splitId: Int): Unit = {

    val mrTaskAttemptID = mrTaskContext.getTaskAttemptID

    // Called after we have decided to commit
    def performCommit(): Unit = {
      try {
        logInfo(s"$mrTaskAttemptID: Committed")
      } catch {
        // I/O failures are logged and then rethrown unchanged.
        case cause: IOException =>
          logError(s"Error committing the output of task: $mrTaskAttemptID", cause)
          throw cause

    // First, check whether the task's output has already been committed by some other attempt
    if (committer.needsTaskCommit(mrTaskContext)) {
      val shouldCoordinateWithDriver: Boolean = {
        val sparkConf = SparkEnv.get.conf
        // We only need to coordinate with the driver if there are concurrent task attempts.
        // Note that this could happen even when speculation is not enabled (e.g. see SPARK-8029).
        // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs.
        sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", defaultValue = true)

      if (shouldCoordinateWithDriver) {
        val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
        // The attempt number of the running Spark task identifies this attempt to the
        // coordinator.
        val taskAttemptNumber = TaskContext.get().attemptNumber()
        val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber)

        if (canCommit) {
        } else {
          val message =
            s"$mrTaskAttemptID: Not committed because the driver did not authorize commit"
          // We need to abort the task so that the driver can reschedule new attempts, if necessary
          throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber)
      } else {
        // Speculation is disabled or a user has chosen to manually bypass the commit coordination
    } else {
      // Some other attempt committed the output, so we do nothing and signal success
      logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID")
Source File: taskListeners.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import java.util.EventListener

import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi

/**
 * Runtime exception carrying the error messages raised by task completion listeners,
 * plus an optional earlier error from the task itself.
 *
 * NOTE(review): this excerpt is truncated -- the single-message branch, the receivers of
 * the two function literals in `getMessage`, and the closing braces are missing.
 *
 * @param errorMessages messages collected from the listeners
 * @param previousError earlier task error, if any; defaults to `None`
 */
class TaskCompletionListenerException(
    errorMessages: Seq[String],
    val previousError: Option[Throwable] = None)
  extends RuntimeException {

  // With several messages, each is rendered as "Exception <i>: <msg>" on its own line.
  // A section starting with "Previous exception in task: " follows, containing the
  // message and tab-indented stack trace of an exception `e` (presumably
  // `previousError` -- TODO confirm, the receiver is cut off).
  override def getMessage: String = {
    if (errorMessages.size == 1) {
    } else { { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
    } + { e =>
      "\n\nPrevious exception in task: " + e.getMessage + "\n" +
        e.getStackTrace.mkString("\t", "\n\t", "")
Source File: ZippedWithIndexRDD.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils

/**
 * Partition that wraps a parent partition and remembers the index assigned to its first
 * element.
 *
 * @param prev       wrapped parent partition; this partition reuses its `index`
 * @param startIndex index given to the first element of the partition
 *
 * NOTE(review): the closing brace of the class is missing from this excerpt.
 */
class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index

  // Start index per partition, built as a running sum (`scanLeft` from 0L) over
  // per-partition sizes obtained with `Utils.getIteratorSize`. Only partitions
  // 0 until n - 1 are counted.
  // NOTE(review): excerpt is truncated -- the results for n == 0 and n == 1, the call
  // that runs the counting job, and the closing braces are missing; `prev` is declared
  // outside the visible lines (presumably the parent RDD -- TODO confirm).
  @transient private val startIndices: Array[Long] = {
    val n = prev.partitions.length
    if (n == 0) {
    } else if (n == 1) {
    } else {
        Utils.getIteratorSize _,
        0 until n - 1 // do not need to count the last partition
      ).scanLeft(0L)(_ + _)

  // Wraps each parent partition `x` in a ZippedWithIndexRDDPartition whose start index
  // is looked up in `startIndices` by the parent's partition index.
  // NOTE(review): the expression between `firstParent[T]` and `=>` and the closing
  // brace are missing from this excerpt.
  override def getPartitions: Array[Partition] = {
    firstParent[T] => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))

  // Preferred hosts for `split`.
  // NOTE(review): the body of this method is missing from this excerpt.
  override def getPreferredLocations(split: Partition): Seq[String] =

  /**
   * Iterates the wrapped parent partition and pairs its elements with indices via
   * `Utils.getIteratorZipWithIndex`, starting at the partition's `startIndex`.
   *
   * NOTE(review): the closing brace of the method is missing from this excerpt.
   */
  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
    val parentIter = firstParent[T].iterator(split.prev, context)
    Utils.getIteratorZipWithIndex(parentIter, split.startIndex)
Source File: UnionRDD.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.{ForkJoinTaskSupport, ThreadPoolTaskSupport}
import scala.concurrent.forkjoin.ForkJoinPool
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

/**
 * Partition that refers to a single partition of one parent RDD.
 *
 * @param idx index of this partition
 * @param rdd parent RDD the partition is taken from (transient)
 * @param parentRddIndex index of the parent RDD
 * @param parentRddPartitionIndex index of the partition inside `rdd` (transient)
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient private val rdd: RDD[T],
    val parentRddIndex: Int,
    @transient private val parentRddPartitionIndex: Int)
  extends Partition {

  // Looked up from the parent RDD at construction time; reassigned again in writeObject.
  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  // Preferred locations are those the parent RDD reports for the referenced partition.
  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  // Custom serialization hook, wrapped in Utils.tryOrIOException.
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)

object UnionRDD {
  // Task support backed by a ForkJoinPool constructed with argument 8; created lazily on first
  // access.
  private[spark] lazy val partitionEvalTaskSupport =
    new ForkJoinTaskSupport(new ForkJoinPool(8))

/**
 * RDD built from the sequence of parent RDDs `rdds`. It passes `Nil` as dependencies to the
 * `RDD` constructor and supplies them through `getDependencies` instead.
 *
 * NOTE(review): several method bodies are incomplete in this listing (missing lines and
 * closing braces).
 */
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  // visible for testing
  // True when there are more parent RDDs than "spark.rdd.parallelListingThreshold" (default 10).
  private[spark] val isPartitionListingParallel: Boolean =
    rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)

  // Fills an array with one UnionPartition per partition of every parent RDD; `pos` numbers the
  // created partitions consecutively across all parents.
  override def getPartitions: Array[Partition] = {
    val parRDDs = if (isPartitionListingParallel) {
      val parArray = rdds.par
      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
    } else {
    val array = new Array[Partition](
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1

  // Adds one RangeDependency per parent RDD; `pos` is the running sum of the partition counts
  // of the parents processed so far.
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length)
      pos += rdd.partitions.length

  // Delegates to the parent RDD selected by `parentRddIndex`, iterating its `parentPartition`.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)

  override def getPreferredLocations(s: Partition): Seq[String] =

  // Drops the reference to the parent RDDs.
  override def clearDependencies() {
    rdds = null
Source File: PartitionwiseSampledRDD.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.random.RandomSampler
import org.apache.spark.util.Utils

/** Partition that wraps a parent partition `prev` together with a `seed` value. */
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  extends Partition with Serializable {
  // Reuses the index of the wrapped parent partition.
  override val index: Int = prev.index

/**
 * RDD over `prev` whose `compute` applies a clone of `sampler` to each parent partition.
 *
 * @param prev parent RDD
 * @param sampler sampler that is cloned for every computed partition
 * @param preservesPartitioning when true, this RDD reports `prev.partitioner` as its partitioner
 * @param seed seed of the `Random` used in `getPartitions` (defaults to a random long)
 */
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
    prev: RDD[T],
    sampler: RandomSampler[T, U],
    preservesPartitioning: Boolean,
    @transient private val seed: Long = Utils.random.nextLong)
  extends RDD[U](prev) {

  @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None

  // Each created partition receives a value drawn from a Random seeded with `seed`.
  // NOTE(review): the mapping expression below is incomplete in this listing.
  override def getPartitions: Array[Partition] = {
    val random = new Random(seed)
    firstParent[T] => new PartitionwiseSampledRDDPartition(x, random.nextLong()))

  override def getPreferredLocations(split: Partition): Seq[String] =

  // Samples the parent partition's iterator with a fresh clone of the sampler.
  override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
    val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
    val thisSampler = sampler.clone
    thisSampler.sample(firstParent[T].iterator(split.prev, context))
Source File: PartitionerAwareUnionRDD.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils

/**
 * RDD over several parent RDDs that must all share exactly one partitioner (enforced by the
 * `require` below); it reuses that partitioner and creates one partition per partitioner slot.
 *
 * NOTE(review): parts of the constructor call and of several method bodies are missing from
 * this listing.
 */
class PartitionerAwareUnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]]
  ) extends RDD[T](sc, => new OneToOneDependency(x))) {
  require(rdds.flatMap(_.partitioner).toSet.size == 1,
    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))

  // Taken from the first parent; the check above guarantees one distinct partitioner.
  override val partitioner = rdds.head.partitioner

  // One PartitionerAwareUnionRDDPartition for each index in [0, numPartitions).
  override def getPartitions: Array[Partition] = {
    val numPartitions = partitioner.get.numPartitions
    (0 until numPartitions).map { index =>
      new PartitionerAwareUnionRDDPartition(rdds, index)

  // Get the location where most of the partitions of parent RDDs are located
  override def getPreferredLocations(s: Partition): Seq[String] = {
    logDebug("Finding preferred location for " + this + ", partition " + s.index)
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    val locations = {
      case (rdd, part) =>
        val parentLocations = currPrefLocs(rdd, part)
        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
    val location = if (locations.isEmpty) {
    } else {
      // Find the location that maximum number of parent partitions prefer
      Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
    logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)

  // Iterates each (parent RDD, parent partition) pair with the given task context.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents {
      case (rdd, p) => rdd.iterator(p, context)

  // Drops the reference to the parent RDDs.
  override def clearDependencies() {
    rdds = null

  // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
    rdd.context.getPreferredLocs(rdd, part.index).map(tl =>
Source File: MemoryTestingUtils.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.memory

import java.util.Properties

import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}

/** Helper for building a `TaskContext` from a `SparkEnv`. */
object MemoryTestingUtils {
  /**
   * Creates a `TaskContextImpl` with stage id, partition id, task attempt id and attempt number
   * all set to 0, empty local properties, the metrics system of `env`, and a new
   * `TaskMemoryManager` built from `env.memoryManager` and 0.
   *
   * @param env environment supplying the memory manager and the metrics system
   */
  def fakeTaskContext(env: SparkEnv): TaskContext = {
    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
    new TaskContextImpl(
      stageId = 0,
      partitionId = 0,
      taskAttemptId = 0,
      attemptNumber = 0,
      taskMemoryManager = taskMemoryManager,
      localProperties = new Properties,
      metricsSystem = env.metricsSystem)
Source File: FakeTask.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.spark.TaskContext

/**
 * Task whose `runTask` always returns 0 and whose preferred locations are the ones passed in.
 * The second argument of the `Task` constructor is fixed to 0.
 *
 * @param stageId stage id forwarded to `Task`
 * @param partitionId partition id forwarded to `Task`
 * @param prefLocs locations reported by `preferredLocations` (empty by default)
 */
class FakeTask(
    stageId: Int,
    partitionId: Int,
    prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, partitionId) {
  override def runTask(context: TaskContext): Int = 0
  override def preferredLocations: Seq[TaskLocation] = prefLocs

/** Factory methods that build a `TaskSet` made of `FakeTask`s. */
object FakeTask {
  // Overload that uses stageAttemptId = 0.
  def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    createTaskSet(numTasks, stageAttemptId = 0, prefLocs: _*)

  // Overload that uses stageId = 0.
  def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    createTaskSet(numTasks, stageId = 0, stageAttemptId, prefLocs: _*)

  /**
   * Builds `numTasks` tasks for `stageId`, with task `i` using partition id `i` and
   * `prefLocs(i)` as its preferred locations when locations were supplied (otherwise `Nil`).
   * The TaskSet is created with priority 0 and `null` as its last argument.
   *
   * Throws IllegalArgumentException when `prefLocs` is non-empty and its size differs from
   * `numTasks`.
   */
  def createTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*):
  TaskSet = {
    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
      throw new IllegalArgumentException("Wrong number of task locations")
    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
      new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil)
    new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
Source File: OutputCommitCoordinatorIntegrationSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.{Seconds, Span}

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext}
import org.apache.spark.util.Utils

/**
 * Test suite that runs a text-file save on a local SparkContext with
 * "spark.hadoop.outputCommitCoordination.enabled" set to "true".
 *
 * NOTE(review): parts of `beforeAll` and of the test body (including the `finally` block) are
 * missing from this listing.
 */
class OutputCommitCoordinatorIntegrationSuite
  extends SparkFunSuite
  with LocalSparkContext
  with Timeouts {

  // Creates the SparkContext with master "local[2, 4]" and the configuration built here.
  override def beforeAll(): Unit = {
    val conf = new SparkConf()
      .set("spark.hadoop.outputCommitCoordination.enabled", "true")
    sc = new SparkContext("local[2, 4]", "test", conf)

  test("exception thrown in OutputCommitter.commitTask()") {
    // Regression test for SPARK-10381
    // The save must finish within 60 seconds, otherwise failAfter fails the test.
    failAfter(Span(60, Seconds)) {
      val tempDir = Utils.createTempDir()
      try {
        sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out")
      } finally {

/**
 * FileOutputCommitter whose `commitTask` throws an "Intentional exception" when the current
 * task's attempt number is below 1.
 *
 * NOTE(review): the exception type on the `throw` line and the rest of the method are missing
 * from this listing.
 */
private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter {
  override def commitTask(context: TaskAttemptContext): Unit = {
    // The attempt number comes from the TaskContext of the running task.
    val ctx = TaskContext.get()
    if (ctx.attemptNumber < 1) {
      throw new"Intentional exception")
Example 79
Source File: PartitionPruningRDDSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext}

/**
 * Tests for `PartitionPruningRDD.create` using small anonymous RDDs made of three
 * `TestPartition`s.
 *
 * NOTE(review): the bodies of `getPartitions` and `compute` in the anonymous RDDs are
 * incomplete in this listing.
 */
class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext {

  test("Pruned Partitions inherit locality prefs correctly") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 1),
          new TestPartition(1, 1),
          new TestPartition(2, 1))

      def compute(split: Partition, context: TaskContext) = {
    // Keep only the partition with index 2; it must become partition 0 of the pruned RDD and
    // still point at parent split 2.
    val prunedRDD = PartitionPruningRDD.create(rdd, _ == 2)
    assert(prunedRDD.partitions.length == 1)
    val p = prunedRDD.partitions(0)
    assert(p.index == 0)
    assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)

  test("Pruned Partitions can be unioned ") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 4),
          new TestPartition(1, 5),
          new TestPartition(2, 6))

      def compute(split: Partition, context: TaskContext) = {
    val prunedRDD1 = PartitionPruningRDD.create(rdd, _ == 0)

    val prunedRDD2 = PartitionPruningRDD.create(rdd, _ == 2)

    // Union of the two pruned RDDs: expects two elements, 4 then 6 (the values given to
    // TestPartition 0 and TestPartition 2 above).
    val merged = prunedRDD1 ++ prunedRDD2
    assert(merged.count() == 2)
    val take = merged.take(2)
    assert(take.apply(0) == 4)
    assert(take.apply(1) == 6)

/** Serializable partition with index `i` that also exposes `value` through `testValue`. */
class TestPartition(i: Int, value: Int) extends Partition with Serializable {
  def index: Int = i
  def testValue: Int = this.value
Source File: UberXGBoostModel.scala    From uberdata   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.sparkts.models

import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.spark.{XGBoost, XGBoostModel}
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD

import scala.collection.JavaConverters._

/**
 * Helpers around XGBoost training and prediction on RDDs.
 *
 * NOTE(review): the method bodies are incomplete in this listing (the ends of both
 * `labelPredict` overloads are missing).
 */
object UberXGBoostModel {
  /**
   * Caches `trainLabel` and trains with `XGBoost.trainWithRDD`, passing
   * `useExternalMemory = true` and `missing = Float.NaN`.
   *
   * @param trainLabel training data
   * @param configMap configuration forwarded to `trainWithRDD`
   * @param round value forwarded to `trainWithRDD` as the third argument
   * @param nWorkers value forwarded to `trainWithRDD` as the fourth argument
   */
  def train(trainLabel: RDD[LabeledPoint],
            configMap: Map[String, Any],
            round: Int,
            nWorkers: Int): XGBoostModel = {
    val trainData = trainLabel.cache
    XGBoost.trainWithRDD(trainData, configMap, round, nWorkers,useExternalMemory = true, missing
      = Float.NaN)

  // Per partition: duplicates the input iterator, builds a DMatrix from one copy and predicts
  // with the broadcast model's booster.
  def labelPredict(testSet: RDD[XGBLabeledPoint],
                   useExternalCache: Boolean,
                   booster: XGBoostModel): RDD[(Float, Float)] = {
    val broadcastBooster = testSet.sparkContext.broadcast(booster)
    testSet.mapPartitions { testData =>
      val (toPredict, toLabel) = testData.duplicate
      val dMatrix = new DMatrix(toPredict)
      val prediction = broadcastBooster.value.booster.predict(dMatrix).flatten.toIterator

  // Predicts on the whole RDD through the broadcast model with missingValue = Float.NaN.
  def labelPredict(testSet: RDD[DenseVector],
                   booster: XGBoostModel): RDD[(Float, Float)] = {
    val broadcastBooster = testSet.sparkContext.broadcast(booster)
    val rdd = testSet.cache
    broadcastBooster.value.predict(testSet,missingValue = Float.NaN).map(value => (value(0),
//    testSet.
//    testSet.mapPartitions { testData =>
//      val (toPredict, toLabel) = testData.duplicate
//      val dMatrix = new DMatrix(toPredict)
//      val prediction = broadcastBooster.value.booster.predict(dMatrix).flatten.toIterator
//    }
Source File: HashShuffleReader.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.shuffle.hash

import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.util.collection.ExternalSorter

/**
 * Shuffle reader that fetches exactly one partition (`endPartition` must equal
 * `startPartition + 1`), optionally aggregates the fetched records and optionally sorts them
 * by key.
 *
 * NOTE(review): the ends of the sort branches and several closing braces are missing from
 * this listing.
 */
private[spark] class HashShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext)
  extends ShuffleReader[K, C]
  require(endPartition == startPartition + 1,
    "Hash shuffle currently only supports fetching one partition")

  private val dep = handle.dependency

  override def read(): Iterator[Product2[K, C]] = {
    val ser = Serializer.getSerializer(dep.serializer)
    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)

    // With an aggregator: combine combiners when map-side combine is enabled, otherwise combine
    // raw values. Both results are wrapped in an InterruptibleIterator.
    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
      } else {
        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")

      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
      case None =>
Source File: SortShuffleWriter.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.shuffle.sort

import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{IndexShuffleBlockManager, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.util.collection.ExternalSorter

/**
 * Shuffle writer that keeps an `ExternalSorter`, a `MapStatus` and per-task
 * `ShuffleWriteMetrics`, which it registers on the task context's metrics.
 *
 * NOTE(review): only `stop` is visible in this listing, and its body is missing some lines
 * and closing braces.
 */
private[spark] class SortShuffleWriter[K, V, C](
    shuffleBlockManager: IndexShuffleBlockManager,
    handle: BaseShuffleHandle[K, V, C],
    mapId: Int,
    context: TaskContext)
  extends ShuffleWriter[K, V] with Logging {

  private val dep = handle.dependency

  private val blockManager = SparkEnv.get.blockManager

  private var sorter: ExternalSorter[K, V, _] = null

  // Are we in the process of stopping? Because map tasks can call stop() with success = true
  // and then call stop() with success = false if they get an exception, we want to make sure
  // we don't try deleting files, etc twice.
  private var stopping = false

  private var mapStatus: MapStatus = null

  private val writeMetrics = new ShuffleWriteMetrics()
  context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)

  /**
   * Stops the writer. Returns `None` when a stop is already in progress or when `success` is
   * false (in which case the map output is removed first); on success returns
   * `Option(mapStatus)`. The sorter reference is cleared in the `finally` block.
   */
  override def stop(success: Boolean): Option[MapStatus] = {
    try {
      if (stopping) {
        return None
      stopping = true
      if (success) {
        return Option(mapStatus)
      } else {
        // The map task failed, so delete our output data.
        shuffleBlockManager.removeDataByMap(dep.shuffleId, mapId)
        return None
    } finally {
      // Clean up our sorter, which may have its own intermediate files
      if (sorter != null) {
        sorter = null
Source File: ActiveJob.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.util.Properties

import org.apache.spark.TaskContext
import org.apache.spark.util.CallSite

/**
 * State of a job: its id, final stage, the function applied per partition, the partitions to
 * compute, call site, listener and properties, plus per-partition completion tracking.
 */
private[spark] class ActiveJob(
    val jobId: Int,
    val finalStage: Stage,
    val func: (TaskContext, Iterator[_]) => _,
    val partitions: Array[Int],
    val callSite: CallSite,
    val listener: JobListener,
    val properties: Properties) {

  val numPartitions = partitions.length
  // One flag per partition, all initially false, and a counter that starts at 0.
  val finished = Array.fill[Boolean](numPartitions)(false)
  var numFinished = 0
Source File: SampledRDD.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.commons.math3.distribution.PoissonDistribution

import org.apache.spark.{Partition, TaskContext}

/** Deprecated partition that wraps a parent partition `prev` together with an `Int` seed. */
@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0")
class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
  // Reuses the index of the wrapped parent partition.
  override val index: Int = prev.index

/**
 * Deprecated RDD that samples `prev`, either with replacement (per-element Poisson counts) or
 * without replacement (keeps an element when a random double is at most `frac`).
 *
 * NOTE(review): parts of `getPartitions` and of the with-replacement branch are missing from
 * this listing.
 *
 * @param prev parent RDD
 * @param withReplacement selects the sampling branch used in `compute`
 * @param frac parameter of the Poisson distribution, or the keep threshold without replacement
 * @param seed seed of the `Random` used in `getPartitions`
 */
@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0")
private[spark] class SampledRDD[T: ClassTag](
    prev: RDD[T],
    withReplacement: Boolean,
    frac: Double,
    seed: Int)
  extends RDD[T](prev) {

  // Each created partition receives an Int drawn from a Random seeded with `seed`.
  override def getPartitions: Array[Partition] = {
    val rg = new Random(seed)
    firstParent[T] => new SampledRDDPartition(x, rg.nextInt))

  override def getPreferredLocations(split: Partition): Seq[String] =

  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
    val split = splitIn.asInstanceOf[SampledRDDPartition]
    if (withReplacement) {
      // For large datasets, the expected number of occurrences of each element in a sample with
      // replacement is Poisson(frac). We use that to get a count for each element.
      val poisson = new PoissonDistribution(frac)

      firstParent[T].iterator(split.prev, context).flatMap { element =>
        val count = poisson.sample()
        if (count == 0) {
          Iterator.empty  // Avoid object allocation when we return 0 items, which is quite often
        } else {
    } else { // Sampling without replacement
      // The filter's Random is seeded with the partition's own seed.
      val rand = new Random(split.seed)
      firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
Source File: SubtractedRDD.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.{HashMap => JHashMap}

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.Dependency
import org.apache.spark.OneToOneDependency
import org.apache.spark.Partition
import org.apache.spark.Partitioner
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
import org.apache.spark.serializer.Serializer

  // Stores the serializer as an Option, so a null argument becomes None.
  // NOTE(review): the remainder of this method is missing from the listing.
  def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
    this.serializer = Option(serializer)

  /**
   * Builds one dependency for each of `rdd1` and `rdd2`: a `OneToOneDependency` when the RDD is
   * already partitioned by `part`, otherwise a `ShuffleDependency` using `part` and
   * `serializer`.
   */
  override def getDependencies: Seq[Dependency[_]] = {
    Seq(rdd1, rdd2).map { rdd =>
      if (rdd.partitioner == Some(part)) {
        logDebug("Adding one-to-one dependency with " + rdd)
        new OneToOneDependency(rdd)
      } else {
        logDebug("Adding shuffle dependency with " + rdd)
        new ShuffleDependency(rdd, part, serializer)

  /**
   * Creates `part.numPartitions` `CoGroupPartition`s. For each parent, the split dependency is
   * a `ShuffleCoGroupSplitDep` when the corresponding dependency is a shuffle dependency, and
   * a `NarrowCoGroupSplitDep` on the parent's partition `i` otherwise.
   *
   * NOTE(review): part of the expression that pairs each RDD with its position `j` is missing
   * from this listing.
   */
  override def getPartitions: Array[Partition] = {
    val array = new Array[Partition](part.numPartitions)
    for (i <- 0 until array.size) {
      // Each CoGroupPartition will depend on rdd1 and rdd2
      array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2) { case (rdd, j) =>
        dependencies(j) match {
          case s: ShuffleDependency[_, _, _] =>
            new ShuffleCoGroupSplitDep(s.shuffleHandle)
          case _ =>
            new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))

  override val partitioner = Some(part)

  /**
   * Computes one partition by collecting the values of the first dependency into a hash map
   * keyed by `K`, then removing every key that appears in the second dependency.
   *
   * NOTE(review): several lines are missing from this listing (the non-null branch of
   * `getSeq`, the consumption of the shuffle reader, and part of the final expression).
   */
  override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
    val partition = p.asInstanceOf[CoGroupPartition]
    val map = new JHashMap[K, ArrayBuffer[V]]
    // Looks up the buffer for `k`; when absent, a new buffer is created and put into the map.
    def getSeq(k: K): ArrayBuffer[V] = {
      val seq = map.get(k)
      if (seq != null) {
      } else {
        val seq = new ArrayBuffer[V]()
        map.put(k, seq)
    // Applies `op` to the records of a split dependency: narrow dependencies are read through
    // the parent RDD's iterator, shuffle dependencies through a reader from the shuffle manager.
    def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match {
      case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
        rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)

      case ShuffleCoGroupSplitDep(handle) =>
        val iter = SparkEnv.get.shuffleManager
          .getReader(handle, partition.index, partition.index + 1, context)
    // the first dep is rdd1; add all values to the map
    integrate(partition.deps(0), t => getSeq(t._1) += t._2)
    // the second dep is rdd2; remove all of its keys
    integrate(partition.deps(1), t => map.remove(t._1)) { t => { (t._1, _) } }.flatten

  // Drops the references to both parent RDDs.
  override def clearDependencies() {
    rdd1 = null
    rdd2 = null

Source File: ZippedWithIndexRDD.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils

/** Partition that wraps a parent partition `prev` together with a `startIndex` value. */
class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
  extends Partition with Serializable {
  // Reuses the index of the wrapped parent partition.
  override val index: Int = prev.index

  @transient private val startIndices: Array[Long] = {
    val n = prev.partitions.size
    if (n == 0) {
    } else if (n == 1) {
    } else {
        Utils.getIteratorSize _,
        0 until n - 1, // do not need to count the last partition
        allowLocal = false
      ).scanLeft(0L)(_ + _)

  override def getPartitions: Array[Partition] = {
    firstParent[T] => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))

  override def getPreferredLocations(split: Partition): Seq[String] =

  /**
   * Computes one partition from the first parent's iterator over the wrapped parent partition;
   * each emitted pair keeps the element and adds the partition's `startIndex` to the second
   * component.
   *
   * NOTE(review): the call between `iterator(...)` and the closure is missing from this listing.
   */
  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
    firstParent[T].iterator(split.prev, context) { x =>
      (x._1, split.startIndex + x._2)
Source File: UnionRDD.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

/**
 * Partition that refers to a single partition of one parent RDD.
 *
 * @param idx index of this partition
 * @param rdd parent RDD the partition is taken from (transient)
 * @param parentRddIndex index of the parent RDD
 * @param parentRddPartitionIndex index of the partition inside `rdd` (transient)
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient rdd: RDD[T],
    val parentRddIndex: Int,
    @transient parentRddPartitionIndex: Int)
  extends Partition {

  // Looked up from the parent RDD at construction time; reassigned again in writeObject.
  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  // Preferred locations are those the parent RDD reports for the referenced partition.
  def preferredLocations() = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  // Custom serialization hook, wrapped in Utils.tryOrIOException.
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)

/**
 * RDD built from the sequence of parent RDDs `rdds`. It passes `Nil` as dependencies to the
 * `RDD` constructor and supplies them through `getDependencies` instead.
 *
 * NOTE(review): several method bodies are incomplete in this listing (missing lines and
 * closing braces).
 */
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  // Fills an array with one UnionPartition per partition of every parent RDD; `pos` numbers the
  // created partitions consecutively across all parents.
  override def getPartitions: Array[Partition] = {
    val array = new Array[Partition](
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1

  // Adds one RangeDependency per parent RDD; `pos` is the running sum of the partition counts
  // of the parents processed so far.
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size)
      pos += rdd.partitions.size

  // Delegates to the parent RDD selected by `parentRddIndex`, iterating its `parentPartition`.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)

  override def getPreferredLocations(s: Partition): Seq[String] =

  // Drops the reference to the parent RDDs.
  override def clearDependencies() {
    rdds = null
Source File: PartitionwiseSampledRDD.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.random.RandomSampler
import org.apache.spark.util.Utils

/** Partition that wraps a parent partition `prev` together with a `seed` value. */
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  extends Partition with Serializable {
  // Reuses the index of the wrapped parent partition.
  override val index: Int = prev.index

/**
 * RDD over `prev` whose `compute` applies a clone of `sampler` to each parent partition.
 *
 * @param prev parent RDD
 * @param sampler sampler that is cloned for every computed partition
 * @param preservesPartitioning when true, this RDD reports `prev.partitioner` as its partitioner
 * @param seed seed of the `Random` used in `getPartitions` (defaults to a random long)
 */
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
    prev: RDD[T],
    sampler: RandomSampler[T, U],
    @transient preservesPartitioning: Boolean,
    @transient seed: Long = Utils.random.nextLong)
  extends RDD[U](prev) {

  @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None

  // Each created partition receives a value drawn from a Random seeded with `seed`.
  // NOTE(review): the mapping expression below is incomplete in this listing.
  override def getPartitions: Array[Partition] = {
    val random = new Random(seed)
    firstParent[T] => new PartitionwiseSampledRDDPartition(x, random.nextLong()))

  override def getPreferredLocations(split: Partition): Seq[String] =

  // Samples the parent partition's iterator with a fresh clone of the sampler.
  override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
    val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
    val thisSampler = sampler.clone
    thisSampler.sample(firstParent[T].iterator(split.prev, context))
Source File: PartitionerAwareUnionRDD.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils

/**
 * RDD over a non-empty sequence of parent RDDs that must all share exactly one partitioner
 * (both conditions are enforced by the `require` calls below); it reuses that partitioner and
 * creates one partition per partitioner slot.
 *
 * NOTE(review): parts of the constructor call and of several method bodies are missing from
 * this listing.
 */
class PartitionerAwareUnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]]
  ) extends RDD[T](sc, => new OneToOneDependency(x))) {
  require(rdds.length > 0)
  require(rdds.flatMap(_.partitioner).toSet.size == 1,
    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))

  // Taken from the first parent; the checks above guarantee one distinct partitioner.
  override val partitioner = rdds.head.partitioner

  // One PartitionerAwareUnionRDDPartition for each index in [0, numPartitions).
  override def getPartitions: Array[Partition] = {
    val numPartitions = partitioner.get.numPartitions
    (0 until numPartitions).map(index => {
      new PartitionerAwareUnionRDDPartition(rdds, index)

  // Get the location where most of the partitions of parent RDDs are located
  override def getPreferredLocations(s: Partition): Seq[String] = {
    logDebug("Finding preferred location for " + this + ", partition " + s.index)
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    val locations = {
      case (rdd, part) => {
        val parentLocations = currPrefLocs(rdd, part)
        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
    val location = if (locations.isEmpty) {
    } else  {
      // Find the location that maximum number of parent partitions prefer
      Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
    logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)

  // Iterates each (parent RDD, parent partition) pair with the given task context.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents {
      case (rdd, p) => rdd.iterator(p, context)

  // Drops the reference to the parent RDDs.
  override def clearDependencies() {
    rdds = null

  // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
    rdd.context.getPreferredLocs(rdd, part.index).map(tl =>
Source File: NotSerializableFakeTask.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import{ObjectInputStream, ObjectOutputStream, IOException}

import org.apache.spark.TaskContext

/**
 * Task that returns an empty byte array and refuses to be serialized when its `stageId` is 0.
 *
 * @param myId not used by the code visible in this listing
 * @param stageId stage id forwarded to `Task`; 0 makes `writeObject` throw
 */
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) extends Task[Array[Byte]](stageId, 0) {
  override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
  override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()

  // Serialization hook: throws IllegalStateException for stage 0.
  private def writeObject(out: ObjectOutputStream): Unit = {
    if (stageId == 0) {
      throw new IllegalStateException("Cannot serialize")

  // Deserialization hook: intentionally does nothing.
  private def readObject(in: ObjectInputStream): Unit = {}
Source File: FakeTask.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.spark.TaskContext

/**
 * Task whose `runTask` always returns 0 and whose preferred locations are the ones passed in.
 * The second argument of the `Task` constructor is fixed to 0.
 */
class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
  override def runTask(context: TaskContext): Int = 0

  override def preferredLocations: Seq[TaskLocation] = prefLocs

/** Factory for a `TaskSet` made of `FakeTask`s. */
object FakeTask {
  /**
   * Builds `numTasks` tasks, with task `i` created as `new FakeTask(i, ...)` and using
   * `prefLocs(i)` as its preferred locations when locations were supplied (otherwise `Nil`).
   * The TaskSet is created with 0, 0, 0 and `null` as its remaining arguments.
   *
   * Throws IllegalArgumentException when `prefLocs` is non-empty and its size differs from
   * `numTasks`.
   */
  def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
      throw new IllegalArgumentException("Wrong number of task locations")
    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
      new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
    new TaskSet(tasks, 0, 0, 0, null)
Source File: PartitionPruningRDDSuite.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.scalatest.FunSuite

import org.apache.spark.{Partition, SharedSparkContext, TaskContext}

/**
 * Tests for `PartitionPruningRDD.create` using small anonymous RDDs made of three
 * `TestPartition`s.
 *
 * NOTE(review): the bodies of `getPartitions` and `compute` in the anonymous RDDs are
 * incomplete in this listing.
 */
class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {

  test("Pruned Partitions inherit locality prefs correctly") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 1),
          new TestPartition(1, 1),
          new TestPartition(2, 1))

      def compute(split: Partition, context: TaskContext) = {
    // Keep only the partition with index 2; it must become partition 0 of the pruned RDD and
    // still point at parent split 2.
    val prunedRDD = PartitionPruningRDD.create(rdd, _ == 2)
    assert(prunedRDD.partitions.length == 1)
    val p = prunedRDD.partitions(0)
    assert(p.index == 0)
    assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)

  test("Pruned Partitions can be unioned ") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 4),
          new TestPartition(1, 5),
          new TestPartition(2, 6))

      def compute(split: Partition, context: TaskContext) = {
    val prunedRDD1 = PartitionPruningRDD.create(rdd, _ == 0)

    val prunedRDD2 = PartitionPruningRDD.create(rdd, _ == 2)

    // Union of the two pruned RDDs: expects two elements, 4 then 6 (the values given to
    // TestPartition 0 and TestPartition 2 above).
    val merged = prunedRDD1 ++ prunedRDD2
    assert(merged.count() == 2)
    val take = merged.take(2)
    assert(take.apply(0) == 4)
    assert(take.apply(1) == 6)

/** Serializable partition with index `i` that also exposes `value` through `testValue`. */
class TestPartition(i: Int, value: Int) extends Partition with Serializable {
  def index = i

  def testValue = this.value

Source File: CarbonTaskCompletionListener.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.carbondata.execution.datasources.tasklisteners

import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.RecordReaderIterator
import org.apache.spark.util.TaskCompletionListener

import org.apache.carbondata.common.logging.LogServiceFactory
import org.apache.carbondata.core.memory.UnsafeMemoryManager
import org.apache.carbondata.core.util.{DataTypeUtil, ThreadLocalTaskInfo}
import org.apache.carbondata.hadoop.internal.ObjectArrayWritable

// Listener type that extends Spark's TaskCompletionListener; it declares no members of
// its own on this line.
trait CarbonCompactionTaskCompletionListener extends TaskCompletionListener

// Task-completion listener for a query that holds a RecordReaderIterator.
// NOTE(review): the bodies of the try/catch and of the freeMemory branch are not
// visible in this excerpt, so what each step releases cannot be confirmed here.
case class CarbonQueryTaskCompletionListenerImpl(iter: RecordReaderIterator[InternalRow],
    freeMemory: Boolean = false) extends CarbonQueryTaskCompletionListener {
  override def onTaskCompletion(context: TaskContext): Unit = {
    // Only act on the iterator when one was supplied.
    if (iter != null) {
      try {
      } catch {
        case e: Exception =>
    // Additional step that runs only when the listener was created with freeMemory = true.
    if (freeMemory) {

// Task-completion listener for a load that holds a RecordWriter and its
// TaskAttemptContext.
// NOTE(review): the try and finally bodies are not visible in this excerpt.
case class CarbonLoadTaskCompletionListenerImpl(recordWriter: RecordWriter[NullWritable,
    taskAttemptContext: TaskAttemptContext) extends CarbonLoadTaskCompletionListener {

  override def onTaskCompletion(context: TaskContext): Unit = {
    // The finally branch runs whether or not the try branch throws.
    try {
    } finally {
Example 94
Source File: SparkUtil.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import org.apache.spark.{SPARK_VERSION, TaskContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ID_KEY

  /**
   * Compares the running Spark version with `xVersion`, looking only at the major and
   * minor components.
   *
   * The components are compared as integers rather than as a single decimal number, so
   * a two-digit minor version is ordered correctly: 2.10 is above 2.3, whereas the
   * float 2.10 equals 2.1 and would sort below 2.3.
   *
   * @param xVersion version to compare against, e.g. "2.3"; a missing minor component
   *                 is treated as 0 and components after the minor are ignored
   * @param isEqualComparision when true, test that major.minor are equal; otherwise
   *                           test that the Spark version is at least `xVersion`
   * @return the result of the selected comparison
   * @throws NumberFormatException if a major or minor component is not an integer
   */
  def isSparkVersionXandAbove(xVersion: String, isEqualComparision: Boolean = false): Boolean = {
    // Extract (major, minor) from a dotted version string, defaulting minor to 0.
    def majorMinor(version: String): (Int, Int) = {
      val parts = version.split("\\.")
      val minor = if (parts.length >= 2) parts(1).trim.toInt else 0
      (parts(0).trim.toInt, minor)
    }

    val (sparkMajor, sparkMinor) = majorMinor(SPARK_VERSION)
    val (xMajor, xMinor) = majorMinor(xVersion)
    // compare the versions
    if (isEqualComparision) {
      sparkMajor == xMajor && sparkMinor == xMinor
    } else {
      sparkMajor > xMajor || (sparkMajor == xMajor && sparkMinor >= xMinor)
    }
  }

  /** Returns whether the running Spark version equals `xVersion` (equality mode of isSparkVersionXandAbove). */
  def isSparkVersionEqualTo(xVersion: String): Boolean =
    isSparkVersionXandAbove(xVersion, isEqualComparision = true)

  /**
   * Sets the EXECUTION_ID_KEY local property of the session's SparkContext to null
   * when the running Spark version is below 2.3; does nothing otherwise.
   */
  def setNullExecutionId(sparkSession: SparkSession): Unit = {
    // On Spark 2.2 and below an "... is already set" exception is thrown unless the
    // property has been reset to null first.
    val needsReset = !SparkUtil.isSparkVersionXandAbove("2.3")
    if (needsReset) {
      val sc = sparkSession.sparkContext
      sc.setLocalProperty(EXECUTION_ID_KEY, null)
    }
  }

Source File: CarbonDropPartitionRDD.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.rdd

import java.util

import scala.collection.JavaConverters._

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.sql.SparkSession

import org.apache.carbondata.core.index.Segment
import org.apache.carbondata.core.indexstore.PartitionSpec
import org.apache.carbondata.core.metadata.SegmentFileStore

// Spark Partition that carries an RDD id, a partition index and a Segment.
case class CarbonDropPartition(rddId: Int, idx: Int, segment: Segment)
  extends Partition {

  // Partition index exposed to Spark; taken directly from `idx`.
  override val index: Int = idx

  // Hash is derived from the RDD id and the partition index only; `segment` is not used.
  override def hashCode(): Int = 41 * (41 + rddId) + idx

// RDD with no parent dependencies (Nil) whose partitions are CarbonDropPartition
// instances and whose compute step yields (String, String) pairs.
// NOTE(review): several expressions in this excerpt are truncated (the source of the
// partition mapping, the SegmentFileStore call, the hasNext body); comments below
// describe only what is visible.
class CarbonDropPartitionRDD(
    @transient private val ss: SparkSession,
    tablePath: String,
    segments: Seq[Segment],
    partitions: util.List[PartitionSpec],
    uniqueId: String)
  extends CarbonRDD[(String, String)](ss, Nil) {

  // Builds one CarbonDropPartition per pair `s`, using s._2 as the index and s._1 as
  // the segment.
  override def internalGetPartitions: Array[Partition] = { {s =>
      CarbonDropPartition(id, s._2, s._1)

  override def internalCompute(
      theSplit: Partition,
      context: TaskContext): Iterator[(String, String)] = {
    val iter = new Iterator[(String, String)] {
      // The statements below run once, when the iterator object is constructed.
      val split = theSplit.asInstanceOf[CarbonDropPartition]
      logInfo("Dropping partition information from : " + split.segment)
      val toBeDeletedSegments = new util.ArrayList[String]()
      val toBeUpdateSegments = new util.ArrayList[String]()
      new SegmentFileStore(

      var finished = false

      override def hasNext: Boolean = {

      // Marks the iterator as finished and returns the two lists as comma-joined
      // strings: first the segments to update, then the segments to delete.
      override def next(): (String, String) = {
        finished = true
        (toBeUpdateSegments.asScala.mkString(","), toBeDeletedSegments.asScala.mkString(","))

Example 96
Source File: CompactionTaskCompletionListener.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.rdd

import java.util

import org.apache.log4j.Logger
import org.apache.spark.TaskContext
import org.apache.spark.sql.carbondata.execution.datasources.tasklisteners.CarbonCompactionTaskCompletionListener
import org.apache.spark.util.CollectionAccumulator

import org.apache.carbondata.common.logging.LogServiceFactory
import org.apache.carbondata.core.scan.result.iterator.RawResultIterator
import org.apache.carbondata.core.segmentmeta.SegmentMetaDataInfo
import org.apache.carbondata.processing.loading.TableProcessingOperations
import org.apache.carbondata.processing.loading.model.CarbonLoadModel
import org.apache.carbondata.processing.merger.{AbstractResultProcessor, CarbonCompactionExecutor, CarbonCompactionUtil}

// Task-completion listener for compaction tasks; on completion it closes the query
// executor for the unsorted and sorted iterator lists and then handles the processor.
// NOTE(review): several statements are truncated in this excerpt (the logging calls,
// the processor clean-up body and the accumulator step).
class CompactionTaskCompletionListener(
    carbonLoadModel: CarbonLoadModel,
    exec: CarbonCompactionExecutor,
    processor: AbstractResultProcessor,
    rawResultIteratorMap: util.Map[String, util.List[RawResultIterator]],
    segmentMetaDataAccumulator: CollectionAccumulator[Map[String, SegmentMetaDataInfo]],
    queryStartTime: Long)
  extends CarbonCompactionTaskCompletionListener {

  val LOGGER: Logger = LogServiceFactory.getLogService(this.getClass.getName)

  override def onTaskCompletion(context: TaskContext): Unit = {
    // close all the query executor service and clean up memory acquired during query processing
    // exec.close is called once per iterator list: UNSORTED_IDX first, then SORTED_IDX.
    if (null != exec) {"Cleaning up query resources acquired during compaction")
      exec.close(rawResultIteratorMap.get(CarbonCompactionUtil.UNSORTED_IDX), queryStartTime)
      exec.close(rawResultIteratorMap.get(CarbonCompactionUtil.SORTED_IDX), queryStartTime)
    // clean up the resources for processor
    if (null != processor) {"Closing compaction processor instance to clean up loading resources")
    // fill segment metadata to accumulator

  // Deletes the local data load folder location for this load model, passing
  // isCompactionFlow = true; exceptions are caught (handler body not visible here).
  private def deleteLocalDataFolders(): Unit = {
    try {"Deleting local folder store location")
      val isCompactionFlow = true
        .deleteLocalDataLoadFolderLocation(carbonLoadModel, isCompactionFlow, false)
    } catch {
      case e: Exception =>

Source File: CarbonRDD.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.rdd

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.hadoop.conf.Configuration
import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.util.SparkSQLUtil

import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.metadata.schema.table.TableInfo
import org.apache.carbondata.core.util._

// CarbonRDD variant that additionally carries a TableInfo in serialized (byte array)
// form and can deserialize it on demand.
abstract class CarbonRDDWithTableInfo[T: ClassTag](
    @transient private val ss: SparkSession,
    @transient private var deps: Seq[Dependency[_]],
    serializedTableInfo: Array[Byte]) extends CarbonRDD[T](ss, deps) {

  // Convenience constructor for a single parent RDD: wraps it in a OneToOneDependency.
  def this(@transient sparkSession: SparkSession, @transient oneParent: RDD[_],
      serializedTableInfo: Array[Byte]) = {
    this (sparkSession, List(new OneToOneDependency(oneParent)), serializedTableInfo)

  // Deserializes the byte array on every call (this is a def, not a cached val).
  def getTableInfo: TableInfo = TableInfo.deserialize(serializedTableInfo)
Source File: InsertTaskCompletionListener.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.rdd

import org.apache.spark.TaskContext
import org.apache.spark.sql.carbondata.execution.datasources.tasklisteners.CarbonLoadTaskCompletionListener
import org.apache.spark.sql.execution.command.ExecutionErrors
import org.apache.spark.util.CollectionAccumulator

import org.apache.carbondata.core.segmentmeta.SegmentMetaDataInfo
import org.apache.carbondata.core.util.{DataTypeUtil, ThreadLocalTaskInfo}
import org.apache.carbondata.processing.loading.{DataLoadExecutor, FailureCauses}
import org.apache.carbondata.spark.util.CommonUtil

// Task-completion listener for insert/load tasks.
// NOTE(review): the bodies of the dataLoadExecutor branch and of the finally block are
// not visible in this excerpt.
class InsertTaskCompletionListener(dataLoadExecutor: DataLoadExecutor,
    executorErrors: ExecutionErrors,
    segmentMetaDataAccumulator: CollectionAccumulator[Map[String, SegmentMetaDataInfo]],
    tableName: String,
    segmentId: String)
  extends CarbonLoadTaskCompletionListener {
  override def onTaskCompletion(context: TaskContext): Unit = {
    try {
      // fill segment metadata to accumulator
      if (null != dataLoadExecutor) {
    catch {
      case e: Exception =>
        // Rethrow only when executorErrors is present and records no earlier failure
        // (failureCauses is NONE).
        if (null != executorErrors && executorErrors.failureCauses == FailureCauses.NONE) {
          // If an error had already happened before task completion, that earlier
          // error is the one that must surface, not this new one, so this exception
          // is rethrown only in the no-earlier-failure case.
          throw e
    finally {
Source File: QueryTaskCompletionListener.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.rdd

import scala.collection.JavaConverters._

import org.apache.hadoop.mapreduce.RecordReader
import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.sql.carbondata.execution.datasources.tasklisteners.CarbonQueryTaskCompletionListener
import org.apache.spark.sql.profiler.{Profiler, QueryTaskEnd}

import org.apache.carbondata.common.logging.LogServiceFactory
import org.apache.carbondata.core.memory.UnsafeMemoryManager
import org.apache.carbondata.core.stats.{QueryStatistic, QueryStatisticsConstants, QueryStatisticsRecorder}
import org.apache.carbondata.core.util.{DataTypeUtil, TaskMetricsMap, ThreadLocalTaskInfo}
import org.apache.carbondata.spark.InitInputMetrics

// Task-completion listener for query tasks: handles the record reader, logs query
// statistics and optionally runs a freeMemory step.
// NOTE(review): the try/catch bodies, the freeMemory branch and parts of logStatistics
// are truncated in this excerpt.
class QueryTaskCompletionListener(freeMemory: Boolean,
    var reader: RecordReader[Void, Object],
    inputMetricsStats: InitInputMetrics, executionId: String, taskId: Int, queryStartTime: Long,
    queryStatisticsRecorder: QueryStatisticsRecorder, split: Partition, queryId: String)
  extends CarbonQueryTaskCompletionListener {
  override def onTaskCompletion(context: TaskContext): Unit = {
    if (reader != null) {
      try {
      } catch {
        case e: Exception =>
      // Drop the reference once the reader has been handled.
      reader = null
    logStatistics(executionId, taskId, queryStartTime, queryStatisticsRecorder, split)
    if (freeMemory) {

  // Records and reports statistics for one task; does nothing when `recorder` is null.
  // The elapsed time used is System.currentTimeMillis - queryStartTime.
  def logStatistics(
      executionId: String,
      taskId: Long,
      queryStartTime: Long,
      recorder: QueryStatisticsRecorder,
      split: Partition
  ): Unit = {
    if (null != recorder) {
      val queryStatistic = new QueryStatistic()
        System.currentTimeMillis - queryStartTime)
      // print executor query statistics for each task_id
      val statistics = recorder.statisticsForTask(taskId, queryStartTime)
      // Profiler reporting happens only when both statistics and executionId are non-null.
      if (statistics != null && executionId != null) {
        Profiler.invokeIfEnable {
          val inputSplit = split.asInstanceOf[CarbonSparkPartition].split.value
          val size = inputSplit.getLength
          // Each file is rendered as "<segmentId>/<file name>".
          val files = { s =>
            s.getSegmentId + "/" + s.getPath.getName
Source File: UpdateDataLoad.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.rdd

import scala.collection.mutable

import org.apache.spark.TaskContext
import org.apache.spark.sql.Row
import org.apache.spark.util.CollectionAccumulator

import org.apache.carbondata.common.CarbonIterator
import org.apache.carbondata.common.logging.LogServiceFactory
import org.apache.carbondata.core.segmentmeta.SegmentMetaDataInfo
import org.apache.carbondata.core.statusmanager.{LoadMetadataDetails, SegmentStatus}
import org.apache.carbondata.core.util.ThreadLocalTaskInfo
import org.apache.carbondata.processing.loading.{DataLoadExecutor, TableProcessingOperations}
import org.apache.carbondata.processing.loading.model.CarbonLoadModel
import org.apache.carbondata.spark.util.CommonUtil

// Helper object exposing the data-load entry point used for update flows.
object UpdateDataLoad {

  // Runs a data load over the rows of `iter` for the given segment and load model.
  // NOTE(review): most argument lists and the listener body are truncated in this
  // excerpt; only the visible control flow is described below.
  def DataLoadForUpdate(
      segId: String,
      index: Long,
      iter: Iterator[Row],
      carbonLoadModel: CarbonLoadModel,
      loadMetadataDetails: LoadMetadataDetails,
      segmentMetaDataAccumulator: CollectionAccumulator[Map[String, SegmentMetaDataInfo]]): Unit = {
    val LOGGER = LogServiceFactory.getLogService(this.getClass.getCanonicalName)
    try {
      // The input rows are wrapped in a NewRddIterator and added to the reader buffer.
      val recordReaders = mutable.Buffer[CarbonIterator[Array[AnyRef]]]()
      recordReaders += new NewRddIterator(iter,

      val loader = new SparkPartitionLoader(carbonLoadModel,
      // Initialize to set carbon properties

      val executor = new DataLoadExecutor
      // A completion listener is registered on the current task's TaskContext.
      TaskContext.get().addTaskCompletionListener { context =>
        // fill segment metadata to accumulator

    } catch {
      // Exceptions are propagated to the caller.
      case e: Exception =>
        throw e
    } finally {
      // Always invoked, whether the load succeeded or failed.
      TableProcessingOperations.deleteLocalDataLoadFolderLocation(carbonLoadModel, false, false)

Source File: CarbonMergeBloomIndexFilesRDD.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.index

import scala.collection.JavaConverters._

import org.apache.spark.Partition
import org.apache.spark.rdd.CarbonMergeFilePartition
import org.apache.spark.sql.SparkSession
import org.apache.spark.TaskContext

import org.apache.carbondata.core.metadata.schema.table.CarbonTable
import org.apache.carbondata.core.util.path.CarbonTablePath
import org.apache.carbondata.index.bloom.BloomIndexFileStore
import org.apache.carbondata.spark.rdd.CarbonRDD

// RDD with no parent dependencies (Nil) that merges bloom index files per segment for
// the given table inside its compute step and returns an Iterator[String].
// NOTE(review): the partition source, the loop header over the bloom indexes and the
// tails of hasNext/next are truncated in this excerpt.
class CarbonMergeBloomIndexFilesRDD(
  @transient private val ss: SparkSession,
  carbonTable: CarbonTable,
  segmentIds: Seq[String],
  bloomIndexNames: Seq[String],
  bloomIndexColumns: Seq[Seq[String]])
  extends CarbonRDD[String](ss, Nil) {

  // Builds one CarbonMergeFilePartition per pair `s`, passing s._2 and s._1 after the
  // RDD id.
  override def internalGetPartitions: Array[Partition] = { {s =>
      CarbonMergeFilePartition(id, s._2, s._1)

  override def internalCompute(theSplit: Partition, context: TaskContext): Iterator[String] = {
    val tablePath = carbonTable.getTablePath
    val split = theSplit.asInstanceOf[CarbonMergeFilePartition]
    // For each pair `dm`: dm._1 is passed to getIndexesStorePath and dm._2 selects the
    // column list from bloomIndexColumns.
    logInfo("Merging bloom index files of " +
      s"segment ${split.segmentId} for ${carbonTable.getTableName}") dm => {
      val dmSegmentPath = CarbonTablePath.getIndexesStorePath(
        tablePath, split.segmentId, dm._1)
      BloomIndexFileStore.mergeBloomIndexFile(dmSegmentPath, bloomIndexColumns(dm._2).asJava)

    val iter = new Iterator[String] {
      var havePair = false
      var finished = false

      override def hasNext: Boolean = {
        if (!finished && !havePair) {
          // NOTE(review): `finished` is set to true immediately before `havePair` is
          // assigned `!finished`, so `havePair` is false after this branch - confirm
          // that this is intended.
          finished = true
          havePair = !finished

      override def next(): String = {
        if (!hasNext) {
          throw new java.util.NoSuchElementException("End of stream")
        havePair = false

Source File: SegmentPruneRDD.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.indexserver

import scala.collection.JavaConverters._

import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.sql.SparkSession

import org.apache.carbondata.core.cache.CacheProvider
import org.apache.carbondata.core.index.{IndexInputFormat, IndexStoreManager}
import org.apache.carbondata.core.indexstore.SegmentWrapper
import org.apache.carbondata.spark.rdd.CarbonRDD

// RDD with no parent dependencies (Nil) producing (String, SegmentWrapper) pairs; the
// segments are pruned with the filter resolver of the given IndexInputFormat.
// NOTE(review): many expressions are truncated in this excerpt (preferred-location
// branches, the segment extraction, the executor id string and the cache-size branches).
class SegmentPruneRDD(@transient private val ss: SparkSession,
    indexInputFormat: IndexInputFormat)
  extends CarbonRDD[(String, SegmentWrapper)](ss, Nil) {

  // Branches on whether the partition's locations are null.
  override protected def getPreferredLocations(split: Partition): Seq[String] = {
    val locations = split.asInstanceOf[IndexRDDPartition].getLocations
    if (locations != null) {
    } else {

  // Reuses the partitions of a DistributedPruneRDD built from the same inputs.
  override protected def internalGetPartitions: Array[Partition] = {
    new DistributedPruneRDD(ss, indexInputFormat).partitions

  override def internalCompute(split: Partition,
      context: TaskContext): Iterator[(String, SegmentWrapper)] = {
    val inputSplits = split.asInstanceOf[IndexRDDPartition].inputSplit
    val segments =
    if (indexInputFormat.getInvalidSegments.size > 0) {
      // clear the segmentMap and from cache in executor when there are invalid segments
    val blockletMap = IndexStoreManager.getInstance
    val prunedSegments = blockletMap
      .pruneSegments(segments.toList.asJava, indexInputFormat.getFilterResolverIntf)
    // The result key is "<executorIP>_<cacheSize>"; the value wraps the pruned segments.
    val executorIP = s"${ }_${
    val cacheSize = if (CacheProvider.getInstance().getCarbonCache != null) {
    } else {
    val value = (executorIP + "_" + cacheSize.toString, new SegmentWrapper(prunedSegments))
Source File: DistributedCountRDD.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.indexserver

import java.util.concurrent.Executors

import scala.collection.JavaConverters._
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
import scala.concurrent.duration.Duration

import org.apache.hadoop.mapred.TaskAttemptID
import org.apache.hadoop.mapreduce.{InputSplit, TaskType}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.sql.SparkSession

import org.apache.carbondata.common.logging.LogServiceFactory
import org.apache.carbondata.core.cache.CacheProvider
import org.apache.carbondata.core.datastore.impl.FileFactory
import org.apache.carbondata.core.index.{IndexInputFormat, IndexStoreManager}
import org.apache.carbondata.core.util.{CarbonProperties, CarbonThreadFactory}
import org.apache.carbondata.spark.rdd.CarbonRDD

// RDD with no parent dependencies (Nil) producing (String, String) pairs; its compute
// step fans the partition's input splits out over futures on a fixed-size thread pool
// and blocks until all of them complete.
// NOTE(review): many expressions are truncated in this excerpt (execution-context
// construction, async branch, result tuple, segment extraction).
class DistributedCountRDD(@transient ss: SparkSession, indexInputFormat: IndexInputFormat)
  extends CarbonRDD[(String, String)](ss, Nil) {

  // NOTE(review): the logger is obtained for classOf[DistributedPruneRDD] although this
  // class is DistributedCountRDD - confirm this is intended.
  @transient private val LOGGER = LogServiceFactory.getLogService(classOf[DistributedPruneRDD]

  // Branches on whether the partition's locations are null.
  override protected def getPreferredLocations(split: Partition): Seq[String] = {
    if (split.asInstanceOf[IndexRDDPartition].getLocations != null) {
    } else {

  override def internalCompute(split: Partition,
      context: TaskContext): Iterator[(String, String)] = {
    val attemptId = new TaskAttemptID(DistributedRDDUtils.generateTrackerId,
      id, TaskType.MAP, split.index, 0)
    val attemptContext = new TaskAttemptContextImpl(FileFactory.getConfiguration, attemptId)
    val inputSplits = split.asInstanceOf[IndexRDDPartition].inputSplit
    // Pool size comes from CarbonProperties.getNumOfThreadsForExecutorPruning.
    val numOfThreads = CarbonProperties.getInstance().getNumOfThreadsForExecutorPruning
    val service = Executors
      .newFixedThreadPool(numOfThreads, new CarbonThreadFactory("IndexPruningPool", true))
    implicit val ec: ExecutionContextExecutor = ExecutionContext
    if (indexInputFormat.ifAsyncCall()) {
      // to clear cache of invalid segments during pre-priming in index server
    // With at most numOfThreads splits, each split gets its own future; otherwise the
    // splits are grouped via DistributedRDDUtils.groupSplits and each group gets one.
    val futures = if (inputSplits.length <= numOfThreads) { {
        split => generateFuture(Seq(split))
    } else {
      DistributedRDDUtils.groupSplits(inputSplits, numOfThreads).map {
        splits => generateFuture(splits)
    // Blocks without a timeout (Duration.Inf) until every future has completed.
    // scalastyle:off awaitresult
    val results = Await.result(Future.sequence(futures), Duration.Inf).flatten
    // scalastyle:on awaitresult
    val executorIP = s"${ }_${
    val cacheSize = if (CacheProvider.getInstance().getCarbonCache != null) {
    } else {
    Iterator((executorIP + "_" + cacheSize.toString,

  // Reuses the partitions of a DistributedPruneRDD built from the same inputs.
  override protected def internalGetPartitions: Array[Partition] = {
    new DistributedPruneRDD(ss, indexInputFormat).partitions

  // Wraps the row-count lookup for a group of input splits in a Future that runs on the
  // supplied execution context.
  private def generateFuture(split: Seq[InputSplit])
    (implicit executionContext: ExecutionContext) = {
    Future {
      val segments = { inputSplit =>
        val distributable = inputSplit.asInstanceOf[IndexInputSplitWrapper]
      val defaultIndex = IndexStoreManager.getInstance
        .getIndex(indexInputFormat.getCarbonTable, split.head
      defaultIndex.getBlockRowCount(defaultIndex, segments.toList.asJava, indexInputFormat

Source File: DistributedShowCacheRDD.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.indexserver

import scala.collection.JavaConverters._

import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hive.DistributionUtil

import org.apache.carbondata.core.index.IndexStoreManager
import org.apache.carbondata.core.indexstore.blockletindex.BlockletIndexFactory
import org.apache.carbondata.hadoop.CarbonInputSplit
import org.apache.carbondata.spark.rdd.CarbonRDD

// RDD with no parent dependencies (Nil) producing strings that describe index cache
// usage, with one partition per executor.
// NOTE(review): several expressions are truncated in this excerpt (the executor
// mapping, preferred-location branches, index-name branches and result strings).
class DistributedShowCacheRDD(@transient private val ss: SparkSession,
    tableUniqueId: String,
    executorCache: Boolean)
  extends CarbonRDD[String](ss, Nil) {

  // Location strings of the form "executor_<host>_<executor>".
  val executorsList: Array[String] = DistributionUtil
    .getExecutors(ss.sparkContext).flatMap {
      case (host, executors) => { executor =>
          s"executor_${ host }_$executor"

  // Branches on whether the partition's locations are null.
  override protected def getPreferredLocations(split: Partition): Seq[String] = {
    if (split.asInstanceOf[IndexRDDPartition].getLocations != null) {
    } else {

  override protected def internalGetPartitions: Array[Partition] = { {
      case (executor, idx) =>
        // create a dummy split for each executor to accumulate the cache size.
        val dummySplit = new CarbonInputSplit()
        new IndexRDDPartition(id, idx, List(dummySplit), Array(executor))

  override def internalCompute(split: Partition, context: TaskContext): Iterator[String] = {
    val indexes = IndexStoreManager.getInstance().getTableIndexForAllTables.asScala
    // tableUniqueId is treated as a comma-separated list of table ids.
    val tableList = tableUniqueId.split(",")
    val iterator = indexes.collect {
      // An empty tableUniqueId selects every table; otherwise only the listed ids.
      case (tableId, tableIndexes) if tableUniqueId.isEmpty || tableList.contains(tableId) =>
        val sizeAndIndexLengths = tableIndexes.asScala
          .map { index =>
            val indexName = if (index.getIndexFactory.isInstanceOf[BlockletIndexFactory]) {
            } else {
              index.getIndexSchema.getRelationIdentifier.getDatabaseName + "_" + index
            if (executorCache) {
              val executorIP = s"${ }_${
              s"${ executorIP }:${ index.getIndexFactory.getCacheSize }:${
            } else {
Example 105
package org.apache.carbondata.indexserver

import scala.collection.JavaConverters._

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hive.DistributionUtil

import org.apache.carbondata.core.index.IndexStoreManager
import org.apache.carbondata.core.metadata.schema.table.CarbonTable
import org.apache.carbondata.hadoop.CarbonInputSplit
import org.apache.carbondata.spark.rdd.CarbonRDD

// RDD with no parent dependencies (Nil) whose compute step clears the given invalid
// segments of `carbonTable` from the IndexStoreManager.
// NOTE(review): the executor mapping, the preferred-location branches and the
// empty-segment branch are truncated in this excerpt.
class InvalidateSegmentCacheRDD(@transient private val ss: SparkSession, carbonTable: CarbonTable,
    invalidSegmentIds: List[String]) extends CarbonRDD[String](ss, Nil) {

  // Location strings of the form "executor_<host>_<executor>".
  val executorsList: Array[String] = DistributionUtil.getExecutors(ss.sparkContext).flatMap {
    case (host, executors) => {
        executor => s"executor_${host}_$executor"

  override def internalCompute(split: Partition,
      context: TaskContext): Iterator[String] = {
    IndexStoreManager.getInstance().clearInvalidSegments(carbonTable, invalidSegmentIds.asJava)

  // Branches on whether the partition's locations are null.
  override protected def getPreferredLocations(split: Partition): Seq[String] = {
    if (split.asInstanceOf[IndexRDDPartition].getLocations != null) {
    } else {

  // Takes a separate branch when there are no invalid segment ids; otherwise builds one
  // IndexRDDPartition per (executor, idx) pair.
  override protected def internalGetPartitions: Array[Partition] = {
    if (invalidSegmentIds.isEmpty) {
    } else { {
        case (executor, idx) =>
          // create a dummy split for each executor to accumulate the cache size.
          val dummySplit = new CarbonInputSplit()
          new IndexRDDPartition(id, idx, List(dummySplit), Array(executor))
Source File: DefaultSource.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.metrics.source.MetricsHandler
import org.apache.spark.sql.sources.{
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

// Constants shared by the MemSQL data source.
object DefaultSource {
  // Fully qualified source name.
  val MEMSQL_SOURCE_NAME          = "com.memsql.spark"
  // Short alias; returned by DefaultSource.shortName.
  val MEMSQL_SOURCE_NAME_SHORT    = "memsql"
  // Keys starting with this prefix are treated as global options: the prefix is
  // stripped and the remainder is merged into the per-call parameters.
  val MEMSQL_GLOBAL_OPTION_PREFIX = "spark.datasource.memsql."

// Spark SQL data source entry point for MemSQL: provides a read relation and a write
// path that returns a relation afterwards.
// NOTE(review): several expressions are truncated in this excerpt (the fold over the
// global options, the table lookup, the writer calls and the exception handling).
class DefaultSource
    extends RelationProvider
    with DataSourceRegister
    with CreatableRelationProvider
    with LazyLogging {

  override def shortName: String = DefaultSource.MEMSQL_SOURCE_NAME_SHORT

  // Merges global options into `params`: every key starting with
  // MEMSQL_GLOBAL_OPTION_PREFIX is added with the prefix stripped; other keys are
  // ignored.
  private def includeGlobalParams(sqlContext: SQLContext,
                                  params: Map[String, String]): Map[String, String] =
      case (params, (k, v)) if k.startsWith(DefaultSource.MEMSQL_GLOBAL_OPTION_PREFIX) =>
        params + (k.stripPrefix(DefaultSource.MEMSQL_GLOBAL_OPTION_PREFIX) -> v)
      case (params, _) => params

  // Read path: chooses the no-pushdown reader when options.disablePushdown is set,
  // otherwise the regular reader.
  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String]): BaseRelation = {
    val params  = CaseInsensitiveMap(includeGlobalParams(sqlContext, parameters))
    val options = MemsqlOptions(params)
    if (options.disablePushdown) {
      MemsqlReaderNoPushdown(MemsqlOptions.getQuery(params), options, sqlContext)
    } else {
      MemsqlReader(MemsqlOptions.getQuery(params), Nil, options, sqlContext)

  // Write path: prepares the target table, writes every partition of `data` through a
  // partition writer, then returns the read relation for the same parameters.
  override def createRelation(sqlContext: SQLContext,
                              mode: SaveMode,
                              parameters: Map[String, String],
                              data: DataFrame): BaseRelation = {
    val opts = CaseInsensitiveMap(includeGlobalParams(sqlContext, parameters))
    val conf = MemsqlOptions(opts)

    // A missing table name is reported with an IllegalArgumentException.
    val table = MemsqlOptions
        throw new IllegalArgumentException(
          s"To write a dataframe to MemSQL you must specify a table name via the '${MemsqlOptions.TABLE_NAME}' parameter"
    JdbcHelpers.prepareTableForWrite(conf, table, mode, data.schema)
    val isReferenceTable = JdbcHelpers.isReferenceTable(conf, table)
    // LoadDataWriterFactory is used when no on-duplicate-key SQL is configured;
    // otherwise BatchInsertWriterFactory.
    val partitionWriterFactory =
      if (conf.onDuplicateKeySQL.isEmpty) {
        new LoadDataWriterFactory(table, conf)
      } else {
        new BatchInsertWriterFactory(table, conf)

    val schema        = data.schema
    // NOTE(review): totalRowCount is a local var incremented inside the
    // foreachPartition closure; confirm whether the incremented value is expected to
    // be visible outside that closure.
    var totalRowCount = 0L
    data.foreachPartition(partition => {
      val writer = partitionWriterFactory.createDataWriter(schema,
      try {
        partition.foreach(record => {
          totalRowCount += 1
      } catch {
        case e: Exception => {
          throw e

    createRelation(sqlContext, parameters)
Source File: MemsqlRDD.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import java.sql.{Connection, PreparedStatement, ResultSet}

import com.memsql.spark.SQLGen.VariableList
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}

// RDD of Rows with no parent dependencies (Nil) that runs a per-partition JDBC query
// and converts the result set to rows, optionally re-encoding columns when the schema
// data types differ from the expected output data types.
// NOTE(review): several expressions are truncated in this excerpt (the listener body,
// the data type lists, the encoder source and the row mapping).
case class MemsqlRDD(query: String,
                     variables: VariableList,
                     options: MemsqlOptions,
                     schema: StructType,
                     expectedOutput: Seq[Attribute],
                     @transient val sc: SparkContext)
    extends RDD[Row](sc, Nil) {

  override protected def getPartitions: Array[Partition] =
    MemsqlQueryHelpers.GetPartitions(options, query, variables)

  override def compute(rawPartition: Partition, context: TaskContext): Iterator[Row] = {
    // JDBC resources start out null and are assigned further down, so that close() can
    // be called safely at any point.
    var closed                     = false
    var rs: ResultSet              = null
    var stmt: PreparedStatement    = null
    var conn: Connection           = null
    var partition: MemsqlPartition = rawPartition.asInstanceOf[MemsqlPartition]

    // Closes `what` if it is non-null; a failure is logged as a warning, not rethrown.
    def tryClose(name: String, what: AutoCloseable): Unit = {
      try {
        if (what != null) { what.close() }
      } catch {
        case e: Exception => logWarning(s"Exception closing $name", e)

    // Closes result set, statement and connection in that order; the `closed` flag
    // makes repeated calls no-ops.
    def close(): Unit = {
      if (closed) { return }
      tryClose("resultset", rs)
      tryClose("statement", stmt)
      tryClose("connection", conn)
      closed = true

    context.addTaskCompletionListener { context =>

    conn = JdbcUtils.createConnectionFactory(partition.connectionInfo)()
    stmt = conn.prepareStatement(partition.query)
    JdbcHelpers.fillStatement(stmt, partition.variables)
    rs = stmt.executeQuery()

    var rowsIter = JdbcUtils.resultSetToRows(rs, schema)

    if (expectedOutput.nonEmpty) {
      val schemaDatatypes   =
      val expectedDatatypes =

      // Re-encode only when the two data type lists differ.
      if (schemaDatatypes != expectedDatatypes) {
        // One encoder per column, selected by the (source type, expected type) pair:
        // String -> Null yields null, and Short/Integer/Long -> Boolean yields
        // `value != 0`.
        val columnEncoders = {
          case ((_: StringType, _: NullType), _)     => ((_: Row) => null)
          case ((_: ShortType, _: BooleanType), i)   => ((r: Row) => r.getShort(i) != 0)
          case ((_: IntegerType, _: BooleanType), i) => ((r: Row) => r.getInt(i) != 0)
          case ((_: LongType, _: BooleanType), i)    => ((r: Row) => r.getLong(i) != 0)

          // Any other pair must have identical types (checked via options.assert); the
          // value is then passed through unchanged.
          case ((l, r), i) => {
            options.assert(l == r, s"MemsqlRDD: unable to encode ${l} into ${r}")
            ((r: Row) => r.get(i))

        rowsIter = rowsIter
          .map(row => Row.fromSeq(

    // The returned iterator is interruptible and is given `close` as its completion
    // callback.
    CompletionIterator[Row, Iterator[Row]](new InterruptibleIterator[Row](context, rowsIter), close)

Source File: DatasourceRDD.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource.receiver

import org.apache.spark.partial.{BoundedDouble, CountEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.streaming.datasource.config.ParametersUtils
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetOperator}
import org.apache.spark.{Logging, Partition, TaskContext}

// RDD of Rows with no parent dependencies (Nil) backed by a DataFrame built from the
// query in `inputSentences`; partitions, compute and preferred locations all delegate
// to that DataFrame's RDD.
// NOTE(review): several expressions are truncated in this excerpt (parts of
// parseInitialQuery and of the offset value computation).
class DatasourceRDD(
                     @transient sqlContext: SQLContext,
                     inputSentences: InputSentences,
                     datasourceParams: Map[String, String]
                   ) extends RDD[Row](sqlContext.sparkContext, Nil) with Logging with ParametersUtils {

  // When defined, isEmpty compares this total with 0 instead of probing the data.
  private var totalCalculated: Option[Long] = None

  private val InitTableName = "initTable"
  private val LimitedTableName = "limitedTable"
  private val TempInitQuery = s"select * from $InitTableName"

  // Without offset conditions the raw query is executed as is; with them, the query is
  // assembled from the parsed query plus condition, order and limit sentences.
  val dataFrame = inputSentences.offsetConditions.fold(sqlContext.sql(inputSentences.query)) { case offset =>
    val parsedQuery = parseInitialQuery
    val conditionsSentence = offset.fromOffset.extractConditionSentence(parsedQuery)
    val orderSentence = offset.fromOffset.extractOrderSentence(parsedQuery, inverse = offset.limitRecords.isEmpty)
    val limitSentence = inputSentences.extractLimitSentence

    sqlContext.sql(parsedQuery + conditionsSentence + orderSentence + limitSentence)

  // Returns the original query unchanged unless it contains (case-insensitively) WHERE,
  // ORDER or a further keyword that is not visible in this excerpt.
  private def parseInitialQuery: String = {
    if (inputSentences.query.toUpperCase.contains("WHERE") ||
      inputSentences.query.toUpperCase.contains("ORDER") ||
    ) {
    } else inputSentences.query

  // Produces the InputSentences for the next round: when the DataFrame has rows and
  // offset conditions exist, the offset value and operator are advanced; otherwise the
  // current inputSentences are returned unchanged.
  def progressInputSentences: InputSentences = {
    if (!dataFrame.rdd.isEmpty()) {
      inputSentences.offsetConditions.fold(inputSentences) { case offset =>

        val offsetValue = if (offset.limitRecords.isEmpty)
        else {
          val limitedQuery = s"select * from $LimitedTableName order by ${} " +
            s"${OffsetOperator.toInverseOrderOperator(offset.fromOffset.operator)} limit 1"


        inputSentences.copy(offsetConditions = Option(offset.copy(fromOffset = offset.fromOffset.copy(
          value = Option(offsetValue),
          operator = OffsetOperator.toProgressOperator(offset.fromOffset.operator)))))
    } else inputSentences

  // Uses the cached total when available; otherwise the RDD is empty if it has no
  // partitions or take(1) returns nothing.
  override def isEmpty(): Boolean = {
    totalCalculated.fold {
      withScope {
        partitions.length == 0 || take(1).length == 0
    } { total => total == 0L }

  override def getPartitions: Array[Partition] = dataFrame.rdd.partitions

  override def compute(thePart: Partition, context: TaskContext): Iterator[Row] = dataFrame.rdd.compute(thePart, context)

  override def getPreferredLocations(thePart: Partition): Seq[String] = dataFrame.rdd.preferredLocations(thePart)
Source File: NetezzaRDD.scala    From spark-netezza   with Apache License 2.0 5 votes vote down vote up

import java.sql.Connection
import java.util.Properties

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.{Partition, SparkContext, TaskContext}

  // Builds the row iterator for one partition: opens a connection, wraps it in a
  // NetezzaDataReader, and registers a task-completion listener that calls close().
  // NOTE(review): this listing looks truncated — several statements and most closing
  // braces are absent (e.g. neither branch of getNext() shows a result value). Confirm
  // against the upstream spark-netezza source before relying on it.
  override def compute(thePart: Partition, context: TaskContext): Iterator[Row] =
    new Iterator[Row] {
      // Iteration state: `closed` guards close(), `finished` is set by getNext() when the
      // reader has nothing left, `gotNext`/`nextValue` hold a row fetched by hasNext.
      var closed = false
      var finished = false
      var gotNext = false
      var nextValue: Row = null

      // Make sure close() runs when the task completes.
      context.addTaskCompletionListener { context => close() }
      val part = thePart.asInstanceOf[NetezzaPartition]
      val conn = getConnection()
      val reader = new NetezzaDataReader(conn, table, columns, filters, part, schema)

      // Fetches the next row from the reader, or flags `finished` when none remain.
      def getNext(): Row = {
        if (reader.hasNext) {

        } else {
          finished = true

      // Closes the reader and the connection; failures are logged as warnings, not rethrown.
      def close() {
        if (closed) return
        try {
          if (null != reader) {
        } catch {
          case e: Exception => logWarning("Exception closing Netezza record reader", e)
        try {
          if (null != conn) {
          logInfo("closed connection")
        } catch {
          case e: Exception => logWarning("Exception closing connection", e)

      // Looks ahead by one row (via getNext) unless a fetched row is already pending.
      override def hasNext: Boolean = {
        if (!finished) {
          if (!gotNext) {
            nextValue = getNext()
            if (finished) {
            gotNext = true

      // Hands out the pending row; throws NoSuchElementException when the stream has ended.
      override def next(): Row = {
        if (!hasNext) {
          throw new NoSuchElementException("End of stream")
        gotNext = false
Source File: StratifiedRepartitionSuite.scala    From mmlspark   with MIT License 5 votes vote down vote up
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.


import{TestObject, TransformerFuzzing}
import org.apache.spark.TaskContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Test suite for the StratifiedRepartition transformer.
// NOTE(review): this listing looks truncated — closing braces and a few statements inside
// the test body are absent; confirm against the upstream mmlspark source.
class StratifiedRepartitionSuite extends TestBase with TransformerFuzzing[StratifiedRepartition] {

  import session.implicits._

  // Column names of the fixture DataFrame.
  val values = "values"
  val colors = "colors"
  val const = "const"

  // Fixture: twelve rows, three per label 0..3, with a constant third column.
  lazy val input = Seq(
    (0, "Blue", 2),
    (0, "Red", 2),
    (0, "Green", 2),
    (1, "Purple", 2),
    (1, "Orange", 2),
    (1, "Indigo", 2),
    (2, "Violet", 2),
    (2, "Black", 2),
    (2, "White", 2),
    (3, "Gray", 2),
    (3, "Yellow", 2),
    (3, "Cerulean", 2)
  ).toDF(values, colors, const)

  test("Assert doing a stratified repartition will ensure all keys exist across all partitions") {
    val inputSchema = new StructType()
      .add(values, IntegerType).add(colors, StringType).add(const, IntegerType)
    val inputEnc = RowEncoder(inputSchema)
    val valuesFieldIndex = inputSchema.fieldIndex(values)
    val numPartitions = 3
    // Skew the data per partition, using the partition id obtained from TaskContext.
    val trainData = input.repartition(numPartitions).select(values, colors, const)
      .mapPartitions(iter => {
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        // Remove all instances of 0 class on partition 1
        if (partId == 1) {
          iter.flatMap(row => {
            if (row.getInt(valuesFieldIndex) <= 0)
            else Some(row)
        } else {
          // Add back at least 3 instances on other partitions
          val oneOfEachExample = List(Row(0, "Blue", 2), Row(1, "Purple", 2), Row(2, "Black", 2), Row(3, "Gray", 2))
    // Some debug to understand what data is on which partition
    trainData.foreachPartition { rows =>
      rows.foreach { row =>
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        println(s"Row: $row partition id: $partId")
    val stratifiedInputData = new StratifiedRepartition().setLabelCol(values)
    // Assert stratified data contains all keys across all partitions, with extra count
    // for it to be evaluated
      .mapPartitions(iter => {
        val actualLabels = => row.getInt(valuesFieldIndex))
        val expectedLabels = (0 to 3).toList
        if (actualLabels != expectedLabels)
          throw new Exception(s"Missing labels, actual: $actualLabels, expected: $expectedLabels")
    // The two remaining checks compare row counts of repartitioned data against trainData.
    val stratifiedMixedInputData = new StratifiedRepartition().setLabelCol(values)
    assert(stratifiedMixedInputData.count() >= trainData.count())
    val stratifiedOriginalInputData = new StratifiedRepartition().setLabelCol(values)
    assert(stratifiedOriginalInputData.count() == trainData.count())

  // Objects handed to the TransformerFuzzing harness: one transformer in Equal mode plus the fixture.
  def testObjects(): Seq[TestObject[StratifiedRepartition]] = List(new TestObject(
    new StratifiedRepartition().setLabelCol(values).setMode(SPConstants.Equal), input))

Example 111
Source File: ReorderedPartitionsRDD.scala    From hail   with MIT License 5 votes vote down vote up
package is.hail.sparkextras

import is.hail.utils.FastSeq
import org.apache.spark.rdd.RDD
import org.apache.spark.{Dependency, NarrowDependency, Partition, TaskContext}

import scala.reflect.ClassTag

/** Partition of a ReorderedPartitionsRDD: sits at position `index` and wraps the parent's `oldPartition`. */
case class ReorderedPartitionsRDDPartition(index: Int, oldPartition: Partition) extends Partition

class ReorderedPartitionsRDD[T](@transient var prev: RDD[T], @transient val oldIndices: Array[Int])(implicit tct: ClassTag[T])
  extends RDD[T](prev.sparkContext, Nil) {

  /**
   * Builds one partition per entry of `oldIndices`: output partition `i` wraps the parent
   * partition found at position `oldIndices(i)`.
   */
  override def getPartitions: Array[Partition] = {
    val parentPartitions = dependencies.head.rdd.asInstanceOf[RDD[T]].partitions
    Array.tabulate[Partition](oldIndices.length) { i =>
      val oldIndex = oldIndices(i)
      val oldPartition = parentPartitions(oldIndex)
      ReorderedPartitionsRDDPartition(i, oldPartition)
    }
  }

  /**
   * Evaluates a reordered partition by reading the parent partition it wraps.
   *
   * The parent is read through `iterator` (not `compute`) so that a persisted or
   * checkpointed parent is reused rather than recomputed.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val parent = dependencies.head.rdd.asInstanceOf[RDD[T]]
    parent.iterator(split.asInstanceOf[ReorderedPartitionsRDDPartition].oldPartition, context)
  }

  /** A single narrow dependency on `prev`: output partition `p` depends only on parent partition `oldIndices(p)`. */
  override def getDependencies: Seq[Dependency[_]] = FastSeq(new NarrowDependency[T](prev) {
    override def getParents(partitionId: Int): Seq[Int] = FastSeq(oldIndices(partitionId))
  })

  /** Drops the cached dependency list (via `super`) and the reference to the parent RDD. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    prev = null
  }

  override def getPreferredLocations(partition: Partition): Seq[String] =
Source File: MapPartitionsWithValueRDD.scala    From hail   with MIT License 5 votes vote down vote up
package is.hail.sparkextras

import org.apache.spark.rdd.RDD
import org.apache.spark.{Partition, TaskContext}

import scala.annotation.meta.param
import scala.reflect.ClassTag

/**
 * Partition that pairs a parent partition with an extra value of type `V`.
 *
 * @param parentPartition the partition of the parent RDD this one wraps
 * @param value the value carried alongside that partition
 */
case class MapPartitionsWithValueRDDPartition[V](
  parentPartition: Partition,
  value: V) extends Partition {
  /** Same index as the wrapped parent partition. */
  def index: Int = parentPartition.index
}

class MapPartitionsWithValueRDD[T: ClassTag, U: ClassTag, V](
  var prev: RDD[T],
  @(transient @param) values: Array[V],
  f: (Int, V, Iterator[T]) => Iterator[U],
  preservesPartitioning: Boolean) extends RDD[U](prev) {

  @transient override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

  /** Wraps every parent partition together with the entry of `values` at the same index. */
  override def getPartitions: Array[Partition] = {
    firstParent[T].partitions.map(p => MapPartitionsWithValueRDDPartition(p, values(p.index)))
  }

  /** Applies `f` to the partition index, the value carried by the partition, and the parent's elements. */
  override def compute(split: Partition, context: TaskContext): Iterator[U] = {
    val p = split.asInstanceOf[MapPartitionsWithValueRDDPartition[V]]
    f(split.index, p.value, firstParent[T].iterator(p.parentPartition, context))
  }

  /** Drops the cached dependency list (via `super`) and the reference to the parent RDD. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    prev = null
  }
Source File: BlockedRDD.scala    From hail   with MIT License 5 votes vote down vote up
package is.hail.sparkextras

import is.hail.utils._
import org.apache.spark.rdd.RDD
import org.apache.spark.{Dependency, NarrowDependency, Partition, TaskContext}

import scala.language.existentials
import scala.reflect.ClassTag

case class BlockedRDDPartition(@transient rdd: RDD[_],
  index: Int,
  first: Int,
  last: Int) extends Partition {
  require(first <= last)

  val parentPartitions: Array[Partition] =

  def range: Range = first to last

class BlockedRDD[T](@transient var prev: RDD[T],
  @transient val partFirst: Array[Int],
  @transient val partLast: Array[Int]
)(implicit tct: ClassTag[T]) extends RDD[T](prev.sparkContext, Nil) {
  assert(partFirst.length == partLast.length)

  /** Creates partition `i` covering parent partitions `partFirst(i)` through `partLast(i)`. */
  override def getPartitions: Array[Partition] = {
    Array.tabulate[Partition](partFirst.length)(i =>
      BlockedRDDPartition(prev, i, partFirst(i), partLast(i)))
  }

  /** Concatenates, in order, the elements of every parent partition held by the block. */
  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val parent = dependencies.head.rdd.asInstanceOf[RDD[T]]
    split.asInstanceOf[BlockedRDDPartition].parentPartitions.iterator.flatMap(p =>
      parent.iterator(p, context))
  }

  override def getDependencies: Seq[Dependency[_]] = {
    FastSeq(new NarrowDependency(prev) {
      def getParents(id: Int): Seq[Int] =

  /** Drops the cached dependency list (via `super`) and the reference to the parent RDD. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    prev = null
  }

    val prevPartitions = prev.partitions
    val range = partition.asInstanceOf[BlockedRDDPartition].range

    val locationAvail = range.flatMap(i =>

    if (locationAvail.isEmpty)
      return FastSeq.empty[String]

    val m = locationAvail.values.max
    locationAvail.filter(_._2 == m)
Example 114
Source File: IndexReadRDD.scala    From hail   with MIT License 5 votes vote down vote up
package is.hail.sparkextras

import is.hail.backend.spark.SparkBackend
import is.hail.utils.Interval
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

import scala.reflect.ClassTag

/** Partition identified by `index` that names one `file` and, optionally, the interval `bounds` attached to it. */
case class IndexedFilePartition(index: Int, file: String, bounds: Option[Interval]) extends Partition

class IndexReadRDD[T: ClassTag](
  @transient val partFiles: Array[String],
  @transient val intervalBounds: Option[Array[Interval]],
  f: (IndexedFilePartition, TaskContext) => T
) extends RDD[T](SparkBackend.sparkContext("IndexReadRDD"), Nil) {
  def getPartitions: Array[Partition] =
    Array.tabulate(partFiles.length) { i =>
      IndexedFilePartition(i, partFiles(i),

  /** Yields exactly one element per partition: the result of `f` for that partition and task context. */
  override def compute(
    split: Partition, context: TaskContext
  ): Iterator[T] = {
    Iterator.single(f(split.asInstanceOf[IndexedFilePartition], context))
  }
Source File: MultiWayZipPartitionsRDD.scala    From hail   with MIT License 5 votes vote down vote up
package is.hail.sparkextras

import org.apache.spark.rdd.RDD
import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}

import scala.reflect.ClassTag

/** Factory for [[MultiWayZipPartitionsRDD]]. */
object MultiWayZipPartitionsRDD {
  /**
   * Builds the zipped RDD using the SparkContext of the first input.
   *
   * @param rdds the RDDs to zip; `rdds.head` is read, so the sequence must not be empty
   * @param f combines the per-RDD iterators of one partition into the output iterator
   */
  def apply[T: ClassTag , V: ClassTag](
    rdds: IndexedSeq[RDD[T]]
  )(f: (Array[Iterator[T]]) => Iterator[V]): MultiWayZipPartitionsRDD[T, V] = {
    new MultiWayZipPartitionsRDD(rdds.head.sparkContext, rdds, f)
  }
}

private case class MultiWayZipPartition(val index: Int, val partitions: IndexedSeq[Partition])
  extends Partition

class MultiWayZipPartitionsRDD[T: ClassTag, V: ClassTag](
  sc: SparkContext,
  var rdds: IndexedSeq[RDD[T]],
  var f: (Array[Iterator[T]]) => Iterator[V]
) extends RDD[V](sc, => new OneToOneDependency(x))) {
  require(rdds.length > 0)
  private val numParts = rdds(0).partitions.length
  require(rdds.forall(rdd => rdd.partitions.length == numParts))

  override val partitioner = None

  override def getPartitions: Array[Partition] = {
    Array.tabulate[Partition](numParts) { i =>
      MultiWayZipPartition(i, => rdd.partitions(i)))

  override def compute(s: Partition, tc: TaskContext) = {
    val partitions = s.asInstanceOf[MultiWayZipPartition].partitions
    val arr = Array.tabulate(rdds.length)(i => rdds(i).iterator(partitions(i), tc))

  /** Drops the cached dependency list (via `super`) plus the references to the inputs and to `f`. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdds = null
    f = null
  }
Source File: OriginUnionRDD.scala    From hail   with MIT License 5 votes vote down vote up
package is.hail.sparkextras

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

// Partition of an OriginUnionRDD. `index` is its position in the union, `originIdx` the
// position of the source RDD among the inputs, and `originPart` the source RDD's partition.
private[hail] class OriginUnionPartition(
  val index: Int,
  val originIdx: Int,
  val originPart: Partition
) extends Partition

class OriginUnionRDD[T: ClassTag, S: ClassTag](
  sc: SparkContext,
  var rdds: IndexedSeq[RDD[T]],
  f: (Int, Int, Iterator[T]) => Iterator[S]
) extends RDD[S](sc, Nil) {
  override def getPartitions: Array[Partition] = {
    val arr = new Array[Partition](
    var i = 0
    for ((rdd, rddIdx) <- rdds.zipWithIndex; part <- rdd.partitions) {
      arr(i) = new OriginUnionPartition(i, rddIdx, part)
      i += 1

  /**
   * One range dependency per input RDD: the partitions of each input map onto a contiguous
   * run of output partitions, starting where the previous input's run ended.
   */
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var i = 0 // first output partition belonging to the current input
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, i, rdd.partitions.length)
      i += rdd.partitions.length
    }
    deps
  }

  /** Applies `f` to the source RDD's position, the source partition's index, and that partition's elements. */
  override def compute(s: Partition, tc: TaskContext): Iterator[S] = {
    val p = s.asInstanceOf[OriginUnionPartition]
    f(p.originIdx, p.originPart.index, parent[T](p.originIdx).iterator(p.originPart, tc))
  }

  /** Drops the cached dependency list (via `super`) and the references to the input RDDs. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdds = null
  }
Source File: ProtoParquetRDD.scala    From sparksql-protobuf   with Apache License 2.0 5 votes vote down vote up
package com.github.saurfang.parquet.proto.spark

import com.github.saurfang.parquet.proto.ProtoMessageParquetInputFormat
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.parquet.proto.ProtoReadSupport
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{NewHadoopRDD, RDD}
import org.apache.spark.{Partition, SparkContext, TaskContext}

import scala.reflect.ClassTag

class ProtoParquetRDD[T <: AbstractMessage : ClassTag](
                                                        sc: SparkContext,
                                                        input: String,
                                                        protoClass: Class[T],
                                                        @transient conf: Configuration
                                                        ) extends RDD[T](sc, Nil) {

  /** Convenience constructor that uses the SparkContext's Hadoop configuration. */
  def this(sc: SparkContext, input: String, protoClass: Class[T]) = {
    this(sc, input, protoClass, sc.hadoopConfiguration)
  }

  // Underlying Hadoop RDD, created on first use: copies `conf` into a JobConf, points it at
  // `input`, and registers the protobuf class name with ProtoReadSupport.
  lazy private[this] val rdd = {
    val jconf = new JobConf(conf)
    FileInputFormat.setInputPaths(jconf, input)
    ProtoReadSupport.setProtobufClass(jconf, protoClass.getName)

    new NewHadoopRDD(sc, classOf[ProtoMessageParquetInputFormat[T]], classOf[Void], protoClass, jconf)
  }

  /** Computes the partition on the underlying RDD and keeps only the second component of each pair. */
  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val keyedRecords = rdd.compute(split, context)
    keyedRecords.map { case (_, message) => message }
  }

  /** Partitions are those reported by the underlying RDD. */
  override protected def getPartitions: Array[Partition] = {
    val underlying = rdd
    underlying.getPartitions
  }
Source File: Neo4jRDD.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.{Partition, SparkContext, TaskContext}

private class Neo4jRDD(
    sc: SparkContext,
    val query: String,
    val neo4jConfig: Neo4jConfig,
    val parameters: Map[String, Any] = Map.empty,
    partitions: Partitions = Partitions())
    extends RDD[Row](sc, Nil) {

  /** Runs `query` with `parameters` extended by the partition's `window` and returns the resulting rows. */
  override def compute(partition: Partition, context: TaskContext): Iterator[Row] = {

    val neo4jPartition: Neo4jPartition = partition.asInstanceOf[Neo4jPartition]

    Executor.execute(neo4jConfig, query, parameters ++ neo4jPartition.window).sparkRows
  }

  /** Creates one Neo4jPartition per effective partition, each with its own skip and limit. */
  override protected def getPartitions: Array[Partition] = {
    val p = partitions.effective()
    Range(0, p.partitions.toInt).map(idx => new Neo4jPartition(idx, p.skip(idx), p.limit(idx))).toArray
  }

  /** Human-readable summary: partitioning, query text and parameters. */
  override def toString(): String = {
    s"Neo4jRDD partitions $partitions $query using $parameters"
  }
Source File: SlidingRDD.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.rdd

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD

/**
 * Partition of a SlidingRDD.
 *
 * @param idx index of this partition
 * @param prev the parent partition whose elements are read
 * @param tail elements appended after the parent partition's own elements
 * @param offset value derived from `step` and the element count of the preceding partitions
 */
class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int)
  extends Partition with Serializable {
  override val index: Int = idx
}

class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int)
  extends RDD[Array[T]](parent) {

  require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1),
    "Window size and step must be greater than 0, " +
      s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.")

  override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
    val part = split.asInstanceOf[SlidingRDDPartition[T]]
    (firstParent[T].iterator(part.prev, context) ++ part.tail)
      .sliding(windowSize, step)

  override def getPreferredLocations(split: Partition): Seq[String] =

  override def getPartitions: Array[Partition] = {
    val parentPartitions = parent.partitions
    val n = parentPartitions.length
    if (n == 0) {
    } else if (n == 1) {
      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0))
    } else {
      val w1 = windowSize - 1
      // Get partition sizes and first w1 elements.
      val (sizes, heads) = parent.mapPartitions { iter =>
        val w1Array = iter.take(w1).toArray
        Iterator.single((w1Array.length + iter.length, w1Array))
      val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]]
      var i = 0
      var cumSize = 0
      var partitionIndex = 0
      while (i < n) {
        val mod = cumSize % step
        val offset = if (mod == 0) 0 else step - mod
        val size = sizes(i)
        if (offset < size) {
          val tail = mutable.ListBuffer.empty[T]
          // Keep appending to the current tail until it has w1 elements.
          var j = i + 1
          while (j < n && tail.length < w1) {
            tail ++= heads(j).take(w1 - tail.length)
            j += 1
          if (sizes(i) + tail.length >= offset + windowSize) {
            partitions +=
              new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset)
            partitionIndex += 1
        cumSize += size
        i += 1

  // TODO: Override methods such as aggregate, which only requires one Spark job.
Source File: CommitFailureTestSource.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.TaskContext
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

class CommitFailureTestSource extends SimpleTextSource {
  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory =
    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new SimpleTextOutputWriter(path, context) {
          var failed = false
          TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
            failed = true
            SimpleTextRelation.callbackCalled = true

          override def write(row: Row): Unit = {
            if (SimpleTextRelation.failWriter) {
              sys.error("Intentional task writer failure for testing purpose.")


          override def close(): Unit = {
            sys.error("Intentional task commitment failure for testing purpose.")

      override def getFileExtension(context: TaskAttemptContext): String = ""

  override def shortName(): String = "commit-failure-test"
Source File: MonotonicallyIncreasingID.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  /** Resets the per-partition counter and stores the partition index shifted left by 33 bits. */
  override protected def initializeInternal(partitionIndex: Int): Unit = {
    count = 0L
    partitionMask = partitionIndex.toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  /** Returns `partitionMask` plus the current counter value, then advances the counter by one. */
  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  /**
   * Generates Java code equivalent to `evalInternal`: two mutable long fields (counter and
   * partition mask) initialised per partition, with the value computed as mask + counter
   * before the counter is incremented.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"
Source File: ShuffledHashJoinExec.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

case class ShuffledHashJoinExec(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    buildSide: BuildSide,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan)
  extends BinaryExecNode with HashJoin {

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))

  override def requiredChildDistribution: Seq[Distribution] =
    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

  /**
   * Builds the hash relation for the build side of the join and records its metrics.
   *
   * The elapsed time is added to "buildTime" in milliseconds and the relation's estimated
   * size to "buildDataSize". The relation is closed by a task-completion listener.
   *
   * @param iter rows of the build side for the current partition
   * @return the hash relation built from `iter`
   */
  private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
    val buildDataSize = longMetric("buildDataSize")
    val buildTime = longMetric("buildTime")
    val start = System.nanoTime()
    val context = TaskContext.get()
    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
    buildTime += (System.nanoTime() - start) / 1000000
    buildDataSize += relation.estimatedSize
    // This relation is usually used until the end of task.
    context.addTaskCompletionListener(_ => relation.close())
    relation
  }

  /** Zips streamed and build partitions, builds the hash relation per partition, and joins against it. */
  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
      val hashed = buildHashedRelation(buildIter)
      join(streamIter, hashed, numOutputRows)
    }
  }
Source File: StateStoreRDD.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.state

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

class StateStoreRDD[T: ClassTag, U: ClassTag](
    dataRDD: RDD[T],
    storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
    checkpointLocation: String,
    operatorId: Long,
    storeVersion: Long,
    keySchema: StructType,
    valueSchema: StructType,
    sessionState: SessionState,
    @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
  extends RDD[U](dataRDD) {

  private val storeConf = new StateStoreConf(sessionState.conf)

  // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
  private val confBroadcast = dataRDD.context.broadcast(
    new SerializableConfiguration(sessionState.newHadoopConf()))

  override protected def getPartitions: Array[Partition] = dataRDD.partitions

  override def getPreferredLocations(partition: Partition): Seq[String] = {
    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)

  /**
   * Obtains the state store for this partition (identified by checkpoint location, operator
   * id and partition index) and runs `storeUpdateFunction` over the partition's input rows.
   */
  override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
    val store = StateStore.get(
      storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
    val inputIter = dataRDD.iterator(partition, ctxt)
    storeUpdateFunction(store, inputIter)
  }
Source File: ReferenceSort.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

case class ReferenceSort(
    sortOrder: Seq[SortOrder],
    global: Boolean,
    child: SparkPlan)
  extends UnaryExecNode {

  override def requiredChildDistribution: Seq[Distribution] =
    if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
    child.execute().mapPartitions( { iterator =>
      val ordering = newOrdering(sortOrder, child.output)
      val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
        TaskContext.get(), ordering = Some(ordering))
      sorter.insertAll( => (r.copy(), null)))
      val baseIterator =
      val context = TaskContext.get()
      CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
    }, preservesPartitioning = true)

  override def output: Seq[Attribute] = child.output

  override def outputOrdering: Seq[SortOrder] = sortOrder

  override def outputPartitioning: Partitioning = child.outputPartitioning
Source File: SparkHadoopMapRedUtil.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mapred


import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext}
import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.internal.Logging

object SparkHadoopMapRedUtil extends Logging {

  private val user = UserGroupInformation.getCurrentUser.getShortUserName
  def commitTask(
      committer: MapReduceOutputCommitter,
      mrTaskContext: MapReduceTaskAttemptContext,
      jobId: Int,
      splitId: Int): Unit = {

    val mrTaskAttemptID = mrTaskContext.getTaskAttemptID

    // Called after we have decided to commit
    def performCommit(): Unit = {
      try {
        logInfo(s"$mrTaskAttemptID: Committed")
      } catch {
        case cause: IOException =>
          logError(s"Error committing the output of task: $mrTaskAttemptID", cause)
          throw cause

    // First, check whether the task's output has already been committed by some other attempt
    if (committer.needsTaskCommit(mrTaskContext)) {
      val shouldCoordinateWithDriver: Boolean = {
        val sparkConf = SparkEnv.get(user).conf
        // We only need to coordinate with the driver if there are concurrent task attempts.
        // Note that this could happen even when speculation is not enabled (e.g. see SPARK-8029).
        // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs.
        sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", defaultValue = true)

      if (shouldCoordinateWithDriver) {
        val outputCommitCoordinator = SparkEnv.get(user).outputCommitCoordinator
        val taskAttemptNumber = TaskContext.get().attemptNumber()
        val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber)

        if (canCommit) {
        } else {
          val message =
            s"$mrTaskAttemptID: Not committed because the driver did not authorize commit"
          // We need to abort the task so that the driver can reschedule new attempts, if necessary
          throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber)
      } else {
        // Speculation is disabled or a user has chosen to manually bypass the commit coordination
    } else {
      // Some other attempt committed the output, so we do nothing and signal success
      logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID")
Example 126
Source File: taskListeners.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import java.util.EventListener

import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi

class TaskCompletionListenerException(
    errorMessages: Seq[String],
    val previousError: Option[Throwable] = None)
  extends RuntimeException {

  override def getMessage: String = {
    if (errorMessages.size == 1) {
    } else { { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
    } + { e =>
      "\n\nPrevious exception in task: " + e.getMessage + "\n" +
        e.getStackTrace.mkString("\t", "\n\t", "")
Source File: ZippedWithIndexRDD.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils

/**
 * Partition of a ZippedWithIndexRDD.
 *
 * @param prev the parent partition being wrapped
 * @param startIndex the index assigned to the first element of this partition
 */
class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index
}

  @transient private val startIndices: Array[Long] = {
    val n = prev.partitions.length
    if (n == 0) {
    } else if (n == 1) {
    } else {
        Utils.getIteratorSize _,
        0 until n - 1 // do not need to count the last partition
      ).scanLeft(0L)(_ + _)

  /** Wraps every parent partition together with the start index computed for it. */
  override def getPartitions: Array[Partition] = {
    firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
  }

  override def getPreferredLocations(split: Partition): Seq[String] =

  /** Pairs each element of the parent partition with an index counted up from the partition's `startIndex`. */
  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
    val parentIter = firstParent[T].iterator(split.prev, context)
    Utils.getIteratorZipWithIndex(parentIter, split.startIndex)
  }
Source File: UnionRDD.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.{ForkJoinTaskSupport, ThreadPoolTaskSupport}
import scala.concurrent.forkjoin.ForkJoinPool
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient private val rdd: RDD[T],
    val parentRddIndex: Int,
    @transient private val parentRddPartitionIndex: Int)
  extends Partition {

  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)

object UnionRDD {
  // Task support backed by a ForkJoinPool of 8 threads, created lazily on first use.
  private[spark] lazy val partitionEvalTaskSupport =
    new ForkJoinTaskSupport(new ForkJoinPool(8))
}

/**
 * RDD whose partitions are the partitions of all RDDs in `rdds`, each wrapped in a
 * UnionPartition.
 *
 * @param sc   SparkContext passed to the RDD constructor
 * @param rdds parent RDDs; a `var` because clearDependencies sets it to null
 *
 * NOTE(review): several lines and all closing braces are missing in this listing (see the
 * truncated branches in getPartitions and the empty getPreferredLocations) — confirm against
 * the original file.
 */
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  // visible for testing
  // True when the number of parent RDDs exceeds "spark.rdd.parallelListingThreshold" (default 10).
  private[spark] val isPartitionListingParallel: Boolean =
    rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)

  // Fills an array with one UnionPartition per parent partition; `pos` is the running index in
  // this RDD.
  // NOTE(review): the results of both branches of `parRDDs` and the array-size expression are
  // missing in this listing.
  override def getPartitions: Array[Partition] = {
    val parRDDs = if (isPartitionListingParallel) {
      val parArray = rdds.par
      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
    } else {
    val array = new Array[Partition](
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1

  // Adds one RangeDependency per parent RDD, constructed with 0, the running offset `pos`, and
  // the parent's partition count; `pos` then advances by that count.
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length)
      pos += rdd.partitions.length

  // Delegates to the parent RDD selected by the partition's parentRddIndex.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)

  // NOTE(review): the right-hand side of this override is missing in this listing.
  override def getPreferredLocations(s: Partition): Seq[String] =

  // Drops the reference to the parent RDDs.
  // NOTE(review): any further statements of this method are not visible in this listing.
  override def clearDependencies() {
    rdds = null
Example 129
Source File: PartitionwiseSampledRDD.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.random.RandomSampler
import org.apache.spark.util.Utils

// Serializable partition that pairs a parent partition with a seed and reuses the parent's index.
// NOTE(review): the closing brace is missing in this listing.
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index

/**
 * RDD that applies a RandomSampler to each partition of the parent RDD.
 *
 * @param prev parent RDD
 * @param sampler sampler that is cloned and applied in compute
 * @param preservesPartitioning when true, the parent's partitioner is kept; otherwise None
 * @param seed seed for the Random that generates one seed per partition (defaults to a random
 *             long from Utils.random)
 *
 * NOTE(review): closing braces and parts of some bodies are missing in this listing — confirm
 * against the original file.
 */
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
    prev: RDD[T],
    sampler: RandomSampler[T, U],
    preservesPartitioning: Boolean,
    @transient private val seed: Long = Utils.random.nextLong)
  extends RDD[U](prev) {

  @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None

  // Draws one long per created partition from a Random seeded with `seed`.
  // NOTE(review): the call that should receive the `x => ...` lambda is missing in this listing.
  override def getPartitions: Array[Partition] = {
    val random = new Random(seed)
    firstParent[T] => new PartitionwiseSampledRDDPartition(x, random.nextLong()))

  // NOTE(review): the right-hand side of this override is missing in this listing.
  override def getPreferredLocations(split: Partition): Seq[String] =

  // Samples the parent partition's iterator with a clone of the sampler.
  // NOTE(review): the partition's seed is not used in the visible lines; a seeding step may have
  // been dropped from this listing — confirm against the original file.
  override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
    val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
    val thisSampler = sampler.clone
    thisSampler.sample(firstParent[T].iterator(split.prev, context))
Example 130
Source File: PartitionerAwareUnionRDD.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils

/**
 * RDD that unions parent RDDs sharing a single partitioner; it exposes that partitioner and
 * creates one PartitionerAwareUnionRDDPartition per partition index.
 *
 * @param sc   SparkContext passed to the RDD constructor
 * @param rdds parent RDDs; a `var` because clearDependencies sets it to null
 *
 * NOTE(review): this listing is heavily truncated (the dependency expression in the `extends`
 * clause, several method bodies and all closing braces are incomplete) — confirm against the
 * original file.
 */
class PartitionerAwareUnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]]
  ) extends RDD[T](sc, => new OneToOneDependency(x))) {
  // All parents that define a partitioner must define the same one.
  require(rdds.flatMap(_.partitioner).toSet.size == 1,
    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))

  override val partitioner = rdds.head.partitioner

  // One partition per partition of the shared partitioner.
  override def getPartitions: Array[Partition] = {
    val numPartitions = partitioner.get.numPartitions
    (0 until numPartitions).map { index =>
      new PartitionerAwareUnionRDDPartition(rdds, index)

  // Get the location where most of the partitions of parent RDDs are located
  override def getPreferredLocations(s: Partition): Seq[String] = {
    logDebug("Finding preferred location for " + this + ", partition " + s.index)
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    // NOTE(review): the collection expression that this `case` clause is applied to is missing
    // in this listing.
    val locations = {
      case (rdd, part) =>
        val parentLocations = currPrefLocs(rdd, part)
        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
    // NOTE(review): the value of the empty branch is missing in this listing.
    val location = if (locations.isEmpty) {
    } else {
      // Find the location that maximum number of parent partitions prefer
      Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
    logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)

  // Iterates the parent partitions of the given partition via each parent RDD's iterator.
  // NOTE(review): the combinator applied to `parents` is missing in this listing.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents {
      case (rdd, p) => rdd.iterator(p, context)

  // Drops the reference to the parent RDDs.
  override def clearDependencies() {
    rdds = null

  // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
    rdd.context.getPreferredLocs(rdd, part.index).map(tl =>
Example 131
Source File: MemoryTestingUtils.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.memory

import java.util.Properties

import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}

// Test helper for building a TaskContext backed by a given SparkEnv.
// NOTE(review): the closing braces are missing in this listing.
object MemoryTestingUtils {
  /**
   * Builds a TaskContextImpl with all ids set to 0, empty local properties, and a fresh
   * TaskMemoryManager created from the env's memory manager.
   *
   * @param env SparkEnv supplying the memory manager and metrics system
   * @return the constructed TaskContext
   */
  def fakeTaskContext(env: SparkEnv): TaskContext = {
    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
    new TaskContextImpl(
      stageId = 0,
      partitionId = 0,
      taskAttemptId = 0,
      attemptNumber = 0,
      taskMemoryManager = taskMemoryManager,
      localProperties = new Properties,
      metricsSystem = env.metricsSystem)
Example 132
Source File: FakeTask.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.spark.TaskContext

/**
 * Minimal Task[Int] whose runTask always returns 0 and whose preferred locations are the ones
 * supplied at construction.
 *
 * @param stageId stage id forwarded to Task
 * @param partitionId partition id forwarded to Task
 * @param prefLocs preferred locations reported by this task (default Nil)
 *
 * NOTE(review): the closing brace is missing in this listing.
 */
class FakeTask(
    stageId: Int,
    partitionId: Int,
    prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, partitionId) {
  override def runTask(context: TaskContext, user: String): Int = 0
  override def preferredLocations: Seq[TaskLocation] = prefLocs

// Factory helpers that build a TaskSet of FakeTasks.
// NOTE(review): the closing braces are missing in this listing.
object FakeTask {
  // Convenience overload: stage attempt id 0.
  def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    createTaskSet(numTasks, stageAttemptId = 0, prefLocs: _*)

  // Convenience overload: stage id 0.
  def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    createTaskSet(numTasks, stageId = 0, stageAttemptId, prefLocs: _*)

  /**
   * Creates `numTasks` FakeTasks (partition ids 0 until numTasks) and wraps them in a TaskSet
   * with priority 0 and null as the last constructor argument.
   *
   * @param prefLocs either empty, or exactly one location list per task
   * @throws IllegalArgumentException if prefLocs is non-empty and its size differs from numTasks
   */
  def createTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*):
  TaskSet = {
    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
      throw new IllegalArgumentException("Wrong number of task locations")
    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
      new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil)
    new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
Example 133
Source File: OutputCommitCoordinatorIntegrationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.{Seconds, Span}

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext}
import org.apache.spark.util.Utils

/**
 * Integration test that saves a small RDD as a text file on a local SparkContext with output
 * commit coordination enabled.
 *
 * NOTE(review): parts of beforeAll, the body of the `finally` clause and all closing braces are
 * missing in this listing — confirm against the original file.
 */
class OutputCommitCoordinatorIntegrationSuite
  extends SparkFunSuite
  with LocalSparkContext
  with Timeouts {

  // Creates the SparkContext used by the test, with master "local[2, 4]".
  override def beforeAll(): Unit = {
    val conf = new SparkConf()
      .set("spark.hadoop.outputCommitCoordination.enabled", "true")
    sc = new SparkContext("local[2, 4]", "test", conf)

  test("exception thrown in OutputCommitter.commitTask()") {
    // Regression test for SPARK-10381
    // The job must finish within 60 seconds, otherwise failAfter fails the test.
    failAfter(Span(60, Seconds)) {
      val tempDir = Utils.createTempDir()
      try {
        sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out")
      } finally {

/**
 * FileOutputCommitter whose commitTask throws on a task's first attempt (attemptNumber < 1).
 *
 * NOTE(review): the exception type and the delegation to `super.commitTask` for later attempts
 * were not visible in the truncated listing and are reconstructed here — confirm against the
 * original file.
 */
private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter {
  override def commitTask(context: TaskAttemptContext): Unit = {
    val ctx = TaskContext.get()
    if (ctx.attemptNumber < 1) {
      // Fully qualified so the block does not depend on an additional import.
      throw new java.io.IOException("Intentional exception")
    }
    super.commitTask(context)
  }
}
Example 134
Source File: PartitionPruningRDDSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext}

/**
 * Tests for PartitionPruningRDD built on anonymous RDDs with three TestPartitions.
 *
 * NOTE(review): the collection constructor around the TestPartition lists, the bodies of the
 * anonymous `compute` methods and all closing braces are missing in this listing — confirm
 * against the original file.
 */
class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext {

  test("Pruned Partitions inherit locality prefs correctly") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 1),
          new TestPartition(1, 1),
          new TestPartition(2, 1))

      def compute(split: Partition, context: TaskContext) = {
    // Keep only the parent partition whose index is 2.
    val prunedRDD = PartitionPruningRDD.create(rdd, _ == 2)
    assert(prunedRDD.partitions.length == 1)
    val p = prunedRDD.partitions(0)
    // The pruned partition is re-indexed from 0 but still points at parent split 2.
    assert(p.index == 0)
    assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)

  test("Pruned Partitions can be unioned ") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 4),
          new TestPartition(1, 5),
          new TestPartition(2, 6))

      def compute(split: Partition, context: TaskContext) = {
    val prunedRDD1 = PartitionPruningRDD.create(rdd, _ == 0)

    val prunedRDD2 = PartitionPruningRDD.create(rdd, _ == 2)

    // Union of the two pruned RDDs: expects the values of partitions 0 and 2 (4 and 6).
    val merged = prunedRDD1 ++ prunedRDD2
    assert(merged.count() == 2)
    val take = merged.take(2)
    assert(take.apply(0) == 4)
    assert(take.apply(1) == 6)

// Serializable test partition exposing its index `i` and an associated test value.
// NOTE(review): the closing brace is missing in this listing.
class TestPartition(i: Int, value: Int) extends Partition with Serializable {
  def index: Int = i
  def testValue: Int = this.value
Example 135
Source File: SlidingRDD.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.rdd

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.{TaskContext, Partition}
import org.apache.spark.rdd.RDD

// Serializable partition holding its own index, the wrapped parent partition, and a `tail` of
// extra elements that SlidingRDD.compute appends after the parent partition's iterator.
// NOTE(review): the closing brace is missing in this listing.
class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T])
  extends Partition with Serializable {
  override val index: Int = idx

/**
 * RDD of Array[T] windows over a parent RDD.
 *
 * @param parent parent RDD (transient)
 * @param windowSize window size; must be greater than 1
 *
 * NOTE(review): the end of compute's expression, the body of getPreferredLocations, the n == 0
 * branch, the final result of getPartitions and all closing braces are missing in this listing
 * — confirm against the original file.
 */
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int)
  extends RDD[Array[T]](parent) {

  require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.")

  // Concatenates the parent partition's iterator with the partition's tail.
  // NOTE(review): the rest of this expression is not visible in this listing.
  override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
    val part = split.asInstanceOf[SlidingRDDPartition[T]]
    (firstParent[T].iterator(part.prev, context) ++ part.tail)

  // NOTE(review): the right-hand side of this override is missing in this listing.
  override def getPreferredLocations(split: Partition): Seq[String] =

  // Builds the SlidingRDDPartitions; for more than one parent partition this runs a job on the
  // parent to fetch the leading elements of partitions 1 until n.
  override def getPartitions: Array[Partition] = {
    val parentPartitions = parent.partitions
    val n = parentPartitions.size
    if (n == 0) {
    } else if (n == 1) {
      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty))
    } else {
      val n1 = n - 1
      val w1 = windowSize - 1
      // Get the first w1 items of each partition, starting from the second partition.
      val nextHeads =
        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true)
      val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]()
      var i = 0
      var partitionIndex = 0
      while (i < n1) {
        var j = i
        val tail = mutable.ListBuffer[T]()
        // Keep appending to the current tail until appended a head of size w1.
        while (j < n1 && nextHeads(j).size < w1) {
          tail ++= nextHeads(j)
          j += 1
        if (j < n1) {
          tail ++= nextHeads(j)
          j += 1
        partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail)
        partitionIndex += 1
        // Skip appended heads.
        i = j
      // If the head of last partition has size w1, we also need to add this partition.
      if (nextHeads.last.size == w1) {
        partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty)

Example 136
Source File: HashShuffleReader.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.shuffle.hash

import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.util.collection.ExternalSorter

/**
 * ShuffleReader that fetches a single shuffle partition and optionally aggregates and sorts it.
 *
 * @param handle shuffle handle providing the shuffle id and the dependency
 * @param startPartition partition to fetch
 * @param endPartition must equal startPartition + 1
 * @param context task context passed to the fetcher and the aggregator
 *
 * NOTE(review): the class-body opening brace, the bodies of both `keyOrdering` cases and all
 * closing braces are missing in this listing — confirm against the original file.
 */
private[spark] class HashShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext)
  extends ShuffleReader[K, C]
  require(endPartition == startPartition + 1,
    "Hash shuffle currently only supports fetching one partition")

  private val dep = handle.dependency

  // Fetches the partition, applies the dependency's aggregator if one is defined, then matches
  // on the dependency's key ordering.
  override def read(): Iterator[Product2[K, C]] = {
    val ser = Serializer.getSerializer(dep.serializer)
    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      // With map-side combine the fetched values are combiners; otherwise they are raw values.
      if (dep.mapSideCombine) {
        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
      } else {
        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")

      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
      case None =>
Example 137
Source File: SortShuffleWriter.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.shuffle.sort

import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.util.collection.ExternalSorter

/**
 * ShuffleWriter backed by an ExternalSorter.
 *
 * @param shuffleBlockResolver used in stop() to remove this map task's output on failure
 * @param handle shuffle handle providing the dependency
 * @param mapId id of the map task whose output is written
 * @param context task context whose shuffle write metrics are set at construction
 *
 * NOTE(review): the write method, parts of stop()'s `finally` clause and all closing braces are
 * missing in this listing — confirm against the original file.
 */
private[spark] class SortShuffleWriter[K, V, C](
    shuffleBlockResolver: IndexShuffleBlockResolver,
    handle: BaseShuffleHandle[K, V, C],
    mapId: Int,
    context: TaskContext)
  extends ShuffleWriter[K, V] with Logging {

  private val dep = handle.dependency

  private val blockManager = SparkEnv.get.blockManager

  private var sorter: ExternalSorter[K, V, _] = null

  // Are we in the process of stopping? Because map tasks can call stop() with success = true
  // and then call stop() with success = false if they get an exception, we want to make sure
  // we don't try deleting files, etc twice.
  private var stopping = false

  private var mapStatus: MapStatus = null

  private val writeMetrics = new ShuffleWriteMetrics()
  context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)

  /**
   * Stops the writer. Returns None if stop() was already called or if `success` is false (in
   * which case the map output is removed); otherwise returns Option(mapStatus).
   */
  override def stop(success: Boolean): Option[MapStatus] = {
    try {
      if (stopping) {
        return None
      stopping = true
      if (success) {
        return Option(mapStatus)
      } else {
        // The map task failed, so delete our output data.
        shuffleBlockResolver.removeDataByMap(dep.shuffleId, mapId)
        return None
    } finally {
      // Clean up our sorter, which may have its own intermediate files
      if (sorter != null) {
        val startTime = System.nanoTime()
        // NOTE(review): the statements between taking startTime and the metrics update below
        // are missing in this listing.
          _.incShuffleWriteTime(System.nanoTime - startTime))
        sorter = null
Example 138
Source File: ActiveJob.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.util.Properties

import org.apache.spark.TaskContext
import org.apache.spark.util.CallSite

/**
 * Holds the parameters of a job together with per-partition completion tracking.
 *
 * @param jobId id of the job
 * @param finalStage the job's ResultStage
 * @param func function taking a TaskContext and an iterator
 * @param partitions partition ids of the job; its length defines numPartitions
 * @param callSite call site associated with the job
 * @param listener JobListener associated with the job
 * @param properties properties associated with the job
 *
 * NOTE(review): the closing brace is missing in this listing.
 */
private[spark] class ActiveJob(
    val jobId: Int,
    val finalStage: ResultStage,
    val func: (TaskContext, Iterator[_]) => _,
    val partitions: Array[Int],
    val callSite: CallSite,
    val listener: JobListener,
    val properties: Properties) {

  val numPartitions = partitions.length
  // One flag per partition, all initially false.
  val finished = Array.fill[Boolean](numPartitions)(false)
  // Counter starting at 0; it is not updated within this class as shown.
  var numFinished = 0
Example 139
Source File: SampledRDD.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.commons.math3.distribution.PoissonDistribution

import org.apache.spark.{Partition, TaskContext}

// Deprecated partition type pairing a parent partition with an Int seed; reuses the parent's index.
// NOTE(review): the closing brace is missing in this listing.
@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0")
class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
  override val index: Int = prev.index

/**
 * Deprecated RDD that samples the parent RDD, with or without replacement.
 *
 * @param prev parent RDD
 * @param withReplacement selects the Poisson-based branch (true) or the filter branch (false)
 * @param frac Poisson mean when sampling with replacement; keep-probability threshold otherwise
 * @param seed seed of the Random that generates one Int seed per partition
 *
 * NOTE(review): the non-empty branch of the with-replacement sampling, the body of
 * getPreferredLocations and all closing braces are missing in this listing — confirm against
 * the original file.
 */
@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0")
private[spark] class SampledRDD[T: ClassTag](
    prev: RDD[T],
    withReplacement: Boolean,
    frac: Double,
    seed: Int)
  extends RDD[T](prev) {

  // Draws one Int per created partition from a Random seeded with `seed`.
  // NOTE(review): the call that should receive the `x => ...` lambda is missing in this listing.
  override def getPartitions: Array[Partition] = {
    val rg = new Random(seed)
    firstParent[T] => new SampledRDDPartition(x, rg.nextInt))

  // NOTE(review): the right-hand side of this override is missing in this listing.
  override def getPreferredLocations(split: Partition): Seq[String] =

  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
    val split = splitIn.asInstanceOf[SampledRDDPartition]
    if (withReplacement) {
      // For large datasets, the expected number of occurrences of each element in a sample with
      // replacement is Poisson(frac). We use that to get a count for each element.
      val poisson = new PoissonDistribution(frac)

      firstParent[T].iterator(split.prev, context).flatMap { element =>
        val count = poisson.sample()
        if (count == 0) {
          Iterator.empty  // Avoid object allocation when we return 0 items, which is quite often
        } else {
    } else { // Sampling without replacement
      // Keep an element when the next random double is <= frac; the Random is seeded per partition.
      val rand = new Random(split.seed)
      firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
Example 140
Source File: ZippedWithIndexRDD.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils

// Serializable partition pairing a parent partition with the start index used for its elements;
// reuses the parent's index.
// NOTE(review): the closing brace is missing in this listing.
class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index

  // Start index per partition, read in getPartitions via startIndices(x.index). For more than
  // one partition the visible tail is a running sum (scanLeft from 0L) over values obtained with
  // Utils.getIteratorSize for partitions 0 until n - 1.
  // NOTE(review): the enclosing class header, the values of the n == 0 and n == 1 branches and
  // the call that precedes `Utils.getIteratorSize _` are missing in this listing.
  @transient private val startIndices: Array[Long] = {
    val n = prev.partitions.length
    if (n == 0) {
    } else if (n == 1) {
    } else {
        Utils.getIteratorSize _,
        0 until n - 1, // do not need to count the last partition
        allowLocal = false
      ).scanLeft(0L)(_ + _)

  // Creates one ZippedWithIndexRDDPartition per parent partition, each carrying its entry from
  // startIndices as the start index.
  // NOTE(review): the call that should receive the `x => ...` lambda is missing in this listing.
  override def getPartitions: Array[Partition] = {
    firstParent[T] => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))

  // NOTE(review): the right-hand side of this override is missing in this listing — restore it
  // from the original file before compiling.
  override def getPreferredLocations(split: Partition): Seq[String] =

  // Maps pairs derived from the parent partition's iterator to (element, startIndex + offset).
  // NOTE(review): the combinators between `.iterator(...)` and the `{ x => ...` block are missing
  // in this listing; `x` is used as a pair (x._1, x._2).
  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
    firstParent[T].iterator(split.prev, context) { x =>
      (x._1, split.startIndex + x._2)
Example 141
Source File: UnionRDD.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

/**
 * Partition of a UnionRDD that wraps a single partition of one parent RDD.
 *
 * @param idx index reported by this partition
 * @param rdd parent RDD that owns the wrapped partition (transient)
 * @param parentRddIndex index of the parent RDD, used by UnionRDD.compute to pick the parent
 * @param parentRddPartitionIndex index of the wrapped partition inside `rdd.partitions` (transient)
 *
 * NOTE(review): the closing braces (and possibly the tail of writeObject) are missing in this
 * listing — confirm against the original file.
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient rdd: RDD[T],
    val parentRddIndex: Int,
    @transient parentRddPartitionIndex: Int)
  extends Partition {

  // Looked up eagerly at construction; refreshed again in writeObject.
  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  // Preferred locations of the wrapped parent partition, as reported by the parent RDD.
  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  // Custom serialization hook, wrapped in Utils.tryOrIOException.
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)

/**
 * RDD whose partitions are the partitions of all RDDs in `rdds`, each wrapped in a
 * UnionPartition.
 *
 * @param sc   SparkContext passed to the RDD constructor
 * @param rdds parent RDDs (declared as a `var`)
 *
 * NOTE(review): the array-size expression, the bodies of getPreferredLocations and
 * clearDependencies, and all closing braces are missing in this listing — confirm against the
 * original file.
 */
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  // Fills an array with one UnionPartition per parent partition; `pos` is the running index in
  // this RDD.
  override def getPartitions: Array[Partition] = {
    val array = new Array[Partition](
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1

  // Adds one RangeDependency per parent RDD, constructed with 0, the running offset `pos`, and
  // the parent's partition count; `pos` then advances by that count.
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length)
      pos += rdd.partitions.length

  // Delegates to the parent RDD selected by the partition's parentRddIndex.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)

  // NOTE(review): the right-hand side of this override is missing in this listing.
  override def getPreferredLocations(s: Partition): Seq[String] =

  // NOTE(review): the body of this method is missing in this listing.
  override def clearDependencies() {
Example 142
Source File: PartitionwiseSampledRDD.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.random.RandomSampler
import org.apache.spark.util.Utils

// Serializable partition that pairs a parent partition with a seed and reuses the parent's index.
// NOTE(review): the closing brace is missing in this listing.
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index

/**
 * RDD that applies a RandomSampler to each partition of the parent RDD.
 *
 * @param prev parent RDD
 * @param sampler sampler that is cloned and applied in compute
 * @param preservesPartitioning when true, the parent's partitioner is kept; otherwise None
 * @param seed seed for the Random that generates one seed per partition (defaults to a random
 *             long from Utils.random)
 *
 * NOTE(review): closing braces and parts of some bodies are missing in this listing — confirm
 * against the original file.
 */
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
    prev: RDD[T],
    sampler: RandomSampler[T, U],
    @transient preservesPartitioning: Boolean,
    @transient seed: Long = Utils.random.nextLong)
  extends RDD[U](prev) {

  @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None

  // Draws one long per created partition from a Random seeded with `seed`.
  // NOTE(review): the call that should receive the `x => ...` lambda is missing in this listing.
  override def getPartitions: Array[Partition] = {
    val random = new Random(seed)
    firstParent[T] => new PartitionwiseSampledRDDPartition(x, random.nextLong()))

  // NOTE(review): the right-hand side of this override is missing in this listing.
  override def getPreferredLocations(split: Partition): Seq[String] =

  // Samples the parent partition's iterator with a clone of the sampler.
  // NOTE(review): the partition's seed is not used in the visible lines; a seeding step may have
  // been dropped from this listing — confirm against the original file.
  override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
    val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
    val thisSampler = sampler.clone
    thisSampler.sample(firstParent[T].iterator(split.prev, context))
Example 143
Source File: PartitionerAwareUnionRDD.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils

/**
 * RDD that unions a non-empty list of parent RDDs sharing a single partitioner; it exposes that
 * partitioner and creates one PartitionerAwareUnionRDDPartition per partition index.
 *
 * @param sc   SparkContext passed to the RDD constructor
 * @param rdds parent RDDs; a `var` because clearDependencies sets it to null
 *
 * NOTE(review): this listing is heavily truncated (the dependency expression in the `extends`
 * clause, several method bodies and all closing braces are incomplete) — confirm against the
 * original file.
 */
class PartitionerAwareUnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]]
  ) extends RDD[T](sc, => new OneToOneDependency(x))) {
  require(rdds.length > 0)
  // All parents that define a partitioner must define the same one.
  require(rdds.flatMap(_.partitioner).toSet.size == 1,
    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))

  override val partitioner = rdds.head.partitioner

  // One partition per partition of the shared partitioner.
  override def getPartitions: Array[Partition] = {
    val numPartitions = partitioner.get.numPartitions
    (0 until numPartitions).map(index => {
      new PartitionerAwareUnionRDDPartition(rdds, index)

  // Get the location where most of the partitions of parent RDDs are located
  override def getPreferredLocations(s: Partition): Seq[String] = {
    logDebug("Finding preferred location for " + this + ", partition " + s.index)
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    // NOTE(review): the collection expression that this `case` clause is applied to is missing
    // in this listing.
    val locations = {
      case (rdd, part) => {
        val parentLocations = currPrefLocs(rdd, part)
        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
    // NOTE(review): the value of the empty branch is missing in this listing.
    val location = if (locations.isEmpty) {
    } else {
      // Find the location that maximum number of parent partitions prefer
      Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
    logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)

  // Iterates the parent partitions of the given partition via each parent RDD's iterator.
  // NOTE(review): the combinator applied to `parents` is missing in this listing.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents {
      case (rdd, p) => rdd.iterator(p, context)

  // Drops the reference to the parent RDDs.
  override def clearDependencies() {
    rdds = null

  // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
    rdd.context.getPreferredLocs(rdd, part.index).map(tl =>
Example 144
Source File: FakeTask.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.spark.TaskContext

// Minimal Task[Int] whose runTask always returns 0 and whose preferred locations are the ones
// supplied at construction (default Nil).
// NOTE(review): the closing brace is missing in this listing.
class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
  override def runTask(context: TaskContext): Int = 0

  override def preferredLocations: Seq[TaskLocation] = prefLocs

// Factory helper that builds a TaskSet of FakeTasks.
// NOTE(review): the closing braces are missing in this listing.
object FakeTask {
  /**
   * Creates `numTasks` FakeTasks and wraps them in a TaskSet constructed with (tasks, 0, 0, 0,
   * null). The loop index `i` is passed as each FakeTask's first constructor argument.
   *
   * @param prefLocs either empty, or exactly one location list per task
   * @throws IllegalArgumentException if prefLocs is non-empty and its size differs from numTasks
   */
  def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
      throw new IllegalArgumentException("Wrong number of task locations")
    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
      new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
    new TaskSet(tasks, 0, 0, 0, null)
Example 145
Source File: OutputCommitCoordinatorIntegrationSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.{Span, Seconds}

import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext}
import org.apache.spark.util.Utils

/**
 * Integration test that saves a small RDD as a text file on a local SparkContext with output
 * commit coordination enabled.
 *
 * NOTE(review): the `test(...)` header that should enclose the regression-test body, parts of
 * beforeAll, the body of the `finally` clause and all closing braces are missing in this
 * listing — confirm against the original file.
 */
class OutputCommitCoordinatorIntegrationSuite
  extends SparkFunSuite
  with LocalSparkContext
  with Timeouts {

  // Creates the SparkContext used by the test, with master "local[2, 4]".
  override def beforeAll(): Unit = {
    val conf = new SparkConf()
      .set("master", "local[2,4]")
      .set("spark.hadoop.outputCommitCoordination.enabled", "true")
    sc = new SparkContext("local[2, 4]", "test", conf)

    // Regression test for SPARK-10381
    // The job must finish within 60 seconds, otherwise failAfter fails the test.
    failAfter(Span(60, Seconds)) {
      val tempDir = Utils.createTempDir()
      try {
        sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out")
      } finally {

/**
 * FileOutputCommitter whose commitTask throws on a task's first attempt (attemptNumber < 1).
 *
 * NOTE(review): the exception type and the delegation to `super.commitTask` for later attempts
 * were not visible in the truncated listing and are reconstructed here — confirm against the
 * original file.
 */
private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter {
  override def commitTask(context: TaskAttemptContext): Unit = {
    val ctx = TaskContext.get()
    if (ctx.attemptNumber < 1) {
      // Fully qualified so the block does not depend on an additional import.
      throw new java.io.IOException("Intentional exception")
    }
    super.commitTask(context)
  }
}
Example 146
Source File: PartitionPruningRDDSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext}

/**
 * Tests for PartitionPruningRDD built on anonymous RDDs with three TestPartitions.
 *
 * NOTE(review): the collection constructor around the TestPartition lists, the bodies of the
 * anonymous `compute` methods and all closing braces are missing in this listing — confirm
 * against the original file.
 */
class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext {

  test("Pruned Partitions inherit locality prefs correctly") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 1),
          new TestPartition(1, 1),
          new TestPartition(2, 1))

      def compute(split: Partition, context: TaskContext) = {
    // Keep only the parent partition whose index is 2.
    val prunedRDD = PartitionPruningRDD.create(rdd, _ == 2)
    assert(prunedRDD.partitions.length == 1)
    val p = prunedRDD.partitions(0)
    // The pruned partition is re-indexed from 0 but still points at parent split 2.
    assert(p.index == 0)
    assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)

  test("Pruned Partitions can be unioned ") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 4),
          new TestPartition(1, 5),
          new TestPartition(2, 6))

      def compute(split: Partition, context: TaskContext) = {
    val prunedRDD1 = PartitionPruningRDD.create(rdd, _ == 0)

    val prunedRDD2 = PartitionPruningRDD.create(rdd, _ == 2)

    // Union of the two pruned RDDs: expects the values of partitions 0 and 2 (4 and 6).
    val merged = prunedRDD1 ++ prunedRDD2
    assert(merged.count() == 2)
    val take = merged.take(2)
    assert(take.apply(0) == 4)
    assert(take.apply(1) == 6)

// Serializable test partition exposing its index `i` and an associated test value.
// NOTE(review): the closing brace is missing in this listing.
class TestPartition(i: Int, value: Int) extends Partition with Serializable {
  def index: Int = i
  def testValue: Int = this.value
Example 147
Source File: HBaseSimpleRDD.scala    From spark-hbase-connector   with Apache License 2.0 5 votes vote down vote up
package it.nerdammer.spark.hbase

import it.nerdammer.spark.hbase.conversion.FieldReader
import org.apache.hadoop.hbase.CellUtil
import org.apache.hadoop.hbase.client.Result
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.rdd.{NewHadoopRDD, RDD}
import org.apache.spark.{Partition, TaskContext}

import scala.reflect.ClassTag

/**
 * RDD that converts the (ImmutableBytesWritable, Result) pairs of a NewHadoopRDD into values of
 * type R, reusing the parent's partitions.
 *
 * @param hadoopHBase parent RDD of HBase rows
 * @param builder supplies the column family and column names read in `conversion`
 * @param saltingLength defaults to 0; not used in the lines visible here
 * @param mapper FieldReader whose `columns` are combined with the builder's columns
 *
 * NOTE(review): the end of `conversion` is garbled in this listing (`.toList :: columns)`) and
 * all closing braces are missing — confirm against the original file.
 */
class HBaseSimpleRDD[R: ClassTag](hadoopHBase: NewHadoopRDD[ImmutableBytesWritable, Result], builder: HBaseReaderBuilder[R], saltingLength: Int = 0)
                       (implicit mapper: FieldReader[R], saltingProvider: SaltingProviderFactory[String]) extends RDD[R](hadoopHBase) {

  override def getPartitions: Array[Partition] = firstParent[(ImmutableBytesWritable, Result)].partitions

  // Maps every (key, row) pair of the parent partition through `conversion`.
  override def compute(split: Partition, context: TaskContext) = {
    // val cleanConversion = sc.clean ---> next version
    firstParent[(ImmutableBytesWritable, Result)].iterator(split, context)
      .map(e => conversion(e._1, e._2))

  // Resolves the columns to read and, for each (family, qualifier) pair, yields Some(cell value
  // bytes) when the row contains the column and None otherwise.
  def conversion(key: ImmutableBytesWritable, row: Result) = {

    val columnNames = HBaseUtils.chosenColumns(builder.columns, mapper.columns)

    val columnNamesFC = HBaseUtils.columnsWithFamily(builder.columnFamily, columnNames)

    val columns = columnNamesFC
      .map(t => (Bytes.toBytes(t._1), Bytes.toBytes(t._2)))
      .map(t => if(row.containsColumn(t._1, t._2)) Some(CellUtil.cloneValue(row.getColumnLatestCell(t._1, t._2)).array) else None)
      .toList :: columns)
Example 148
Source File: SlidingRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.rdd

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.{TaskContext, Partition}
import org.apache.spark.rdd.RDD

// Serializable partition holding its own index, the wrapped parent partition, and a `tail` of
// extra elements that SlidingRDD.compute appends after the parent partition's iterator.
// NOTE(review): the closing brace is missing in this listing.
class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T])
  extends Partition with Serializable {
  override val index: Int = idx

/**
 * RDD of Array[T] windows over a parent RDD.
 *
 * @param parent parent RDD (transient)
 * @param windowSize window size; must be greater than 1
 *
 * NOTE(review): the end of compute's expression, the body of getPreferredLocations, the n == 0
 * branch, the final result of getPartitions and all closing braces are missing in this listing
 * — confirm against the original file.
 */
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int)
  extends RDD[Array[T]](parent) {

  require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.")

  // Concatenates the parent partition's iterator with the partition's tail.
  // NOTE(review): the rest of this expression is not visible in this listing.
  override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
    val part = split.asInstanceOf[SlidingRDDPartition[T]]
    (firstParent[T].iterator(part.prev, context) ++ part.tail)

  // NOTE(review): the right-hand side of this override is missing in this listing.
  override def getPreferredLocations(split: Partition): Seq[String] =

  // Builds the SlidingRDDPartitions; for more than one parent partition this runs a job on the
  // parent to fetch the leading elements of partitions 1 until n.
  override def getPartitions: Array[Partition] = {
    val parentPartitions = parent.partitions
    val n = parentPartitions.size
    if (n == 0) {
    } else if (n == 1) {
      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty))
    } else {
      val n1 = n - 1
      val w1 = windowSize - 1
      // Get the first w1 items of each partition, starting from the second partition.
      val nextHeads =
        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n)
      val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]()
      var i = 0
      var partitionIndex = 0
      while (i < n1) {
        var j = i
        val tail = mutable.ListBuffer[T]()
        // Keep appending to the current tail until appended a head of size w1.
        while (j < n1 && nextHeads(j).size < w1) {
          tail ++= nextHeads(j)
          j += 1
        if (j < n1) {
          tail ++= nextHeads(j)
          j += 1
        partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail)
        partitionIndex += 1
        // Skip appended heads.
        i = j
      // If the head of last partition has size w1, we also need to add this partition.
      if (nextHeads.last.size == w1) {
        partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty)

  // TODO: Override methods such as aggregate, which only requires one Spark job.
Example 149
Source File: MonotonicallyIncreasingID.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types.{LongType, DataType}

  // Per-instance counter; set to 0 in initInternal and incremented once per evalInternal call.
  @transient private[this] var count: Long = _

  // Partition id shifted left by 33 bits; assigned in initInternal.
  @transient private[this] var partitionMask: Long = _

  /**
   * Resets the counter to zero and records the current task's partition id,
   * shifted left by 33 bits so it sits above the bits used by the counter.
   */
  override protected def initInternal(): Unit = {
    count = 0L
    partitionMask = TaskContext.getPartitionId().toLong << 33
  }

  // This expression never reports a null result.
  override def nullable: Boolean = false

  // Values are produced as LongType.
  override def dataType: DataType = LongType

  /**
   * Returns the partition mask plus the counter value observed before this call,
   * then advances the counter by one. The input row is not read.
   */
  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  /**
   * Emits Java source equivalent to evalInternal: two mutable long fields (counter and
   * partition mask) are registered on the codegen context, and the generated snippet
   * yields `mask + counter` and then increments the counter.
   *
   * @param ctx codegen context used for fresh names and mutable state registration
   * @param ev  holder whose `isNull` is set to "false" and whose `primitive` names the result
   * @return the generated Java code fragment
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
      s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")

    ev.isNull = "false"
    // The increment mirrors `count += 1` in evalInternal so both paths yield the same ids.
    s"""
      final ${ctx.javaType(dataType)} ${ev.primitive} = $partitionMaskTerm + $countTerm;
      $countTerm++;
    """
  }
Example 150
Source File: randomExpressions.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

/**
 * Expression producing one `nextGaussian()` draw per input row.
 *
 * @param seed base seed; the generated code adds the task's partition id to it when it
 *             constructs its random generator
 */
case class Randn(seed: Long) extends RDG {
  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  /** Creates the expression with a seed drawn from `Utils.random`. */
  def this() = this(Utils.random.nextLong())

  /**
   * Creates the expression from a seed expression, which must be an integer literal.
   * Any other expression is rejected with an AnalysisException.
   */
  def this(seed: Expression) = this(seed match {
    case IntegerLiteral(s) => s
    // The message names this function (randn), not its sibling rand.
    case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
  })

  /**
   * Emits Java source that registers a generator field seeded with `seed + partitionId`
   * and assigns one Gaussian draw to the result variable.
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val rngTerm = ctx.freshName("rng")
    val className = classOf[XORShiftRandom].getName
    ctx.addMutableState(className, rngTerm,
      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
    ev.isNull = "false"
    s"""
      final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian();
    """
  }
}
Source File: BroadcastLeftSemiJoinHash.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
 * Physical plan node that builds a structure from the rows of `right`, broadcasts it, and
 * passes each partition of `left` through `hashSemiJoin` against the broadcast value.
 *
 * NOTE(review): this listing appears truncated (closing braces and some statements are
 * absent); comments describe only the visible lines.
 */
case class BroadcastLeftSemiJoinHash(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    left: SparkPlan,
    right: SparkPlan,
    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

  // Long metrics for left rows, right rows and output rows, keyed by name.
  override private[sql] lazy val metrics = Map(
    "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
    "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numLeftRows = longMetric("numLeftRows")
    val numRightRows = longMetric("numRightRows")
    val numOutputRows = longMetric("numOutputRows")

    // Counts every row coming from the right side while mapping over it.
    val input = right.execute().map { row =>
      numRightRows += 1

    if (condition.isEmpty) {
      // No extra condition: a key hash set is built and broadcast.
      val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric)
      val broadcastedRelation = sparkContext.broadcast(hashSet)

      left.execute().mapPartitions { streamIter =>
        hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows)
    } else {
      // With a condition: a HashedRelation is built and broadcast instead.
      val hashRelation =
        HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size)
      val broadcastedRelation = sparkContext.broadcast(hashRelation)

      left.execute().mapPartitions { streamIter =>
        val hashedRelation = broadcastedRelation.value
        // NOTE(review): the bodies of both cases below are not present in the listing.
        hashedRelation match {
          case unsafe: UnsafeHashedRelation =>
          case _ =>
        hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows)
Source File: ActiveJob.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.util.Properties

import org.apache.spark.TaskContext
import org.apache.spark.util.CallSite

/**
 * State kept for one job, including per-partition completion flags.
 *
 * @param jobId      unique id assigned to the job
 * @param finalStage the final stage of the job
 * @param func       function applied on the final stage
 * @param partitions list of partitions from which data is read and processed
 * @param callSite   call site associated with the job
 * @param listener   job listener
 * @param properties properties associated with the job
 */
private[spark] class ActiveJob(
    val jobId: Int,
    val finalStage: ResultStage,
    val func: (TaskContext, Iterator[_]) => _,
    val partitions: Array[Int],
    val callSite: CallSite,
    val listener: JobListener,
    val properties: Properties) {

  /** Number of partitions in `partitions`. */
  val numPartitions: Int = partitions.length

  /** One flag per partition, all initially false. */
  val finished: Array[Boolean] = Array.fill[Boolean](numPartitions)(false)

  /** Counter starting at zero; this class does not update it itself. */
  var numFinished: Int = 0
}
Source File: SampledRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.commons.math3.distribution.PoissonDistribution

import org.apache.spark.{Partition, TaskContext}

/**
 * Partition that wraps a parent partition together with a seed and reuses the parent's index.
 *
 * @param prev the wrapped parent partition
 * @param seed seed carried with this partition
 */
@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0")
class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
  override val index: Int = prev.index
}

/**
 * RDD that samples the elements of `prev`, with or without replacement.
 *
 * NOTE(review): this listing appears truncated (closing braces and several expressions are
 * absent); comments describe only the visible lines.
 *
 * @param prev            RDD to sample from
 * @param withReplacement selects the with-replacement branch of `compute`
 * @param frac            Poisson mean (with replacement) or keep threshold (without)
 * @param seed            seed for the generator that produces the per-partition seeds
 */
@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0")
private[spark] class SampledRDD[T: ClassTag](
    prev: RDD[T],
    withReplacement: Boolean,
    frac: Double,
    seed: Int)
  extends RDD[T](prev) {

  // Each partition receives its own seed drawn from a generator seeded with `seed`.
  // NOTE(review): the expression on the last line below looks garbled by extraction.
  override def getPartitions: Array[Partition] = {
    val rg = new Random(seed)
    firstParent[T] => new SampledRDDPartition(x, rg.nextInt))

  // NOTE(review): the right-hand side of this definition is not present in the listing.
  override def getPreferredLocations(split: Partition): Seq[String] =

  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
    val split = splitIn.asInstanceOf[SampledRDDPartition]
    if (withReplacement) {
      // For large datasets, the expected number of occurrences of each element in a sample with
      // replacement is Poisson(frac). We use that to get a count for each element.
      val poisson = new PoissonDistribution(frac)

      firstParent[T].iterator(split.prev, context).flatMap { element =>
        val count = poisson.sample()
        if (count == 0) {
          Iterator.empty  // Avoid object allocation when we return 0 items, which is quite often
        } else {
    } else { // Sampling without replacement
      // Keeps an element when a uniform draw from the partition-seeded generator is <= frac.
      val rand = new Random(split.seed)
      firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
Source File: ZippedWithIndexRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils

/**
 * Partition that wraps a parent partition together with a start index and reuses the
 * parent's index.
 *
 * @param prev       the wrapped parent partition
 * @param startIndex start index carried with this partition
 */
class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index
}

  // Start index per partition, computed as a running sum (scanLeft from 0L) over values
  // obtained with Utils.getIteratorSize for partitions 0 until n - 1.
  // NOTE(review): the branch bodies and the call that opens the argument list are not
  // present in this listing.
  @transient private val startIndices: Array[Long] = {
    val n = prev.partitions.length
    if (n == 0) {
    } else if (n == 1) {
    } else {
        Utils.getIteratorSize _,
        0 until n - 1 // do not need to count the last partition
      ).scanLeft(0L)(_ + _)

  // Wraps partitions in ZippedWithIndexRDDPartition, looking up each start index by the
  // partition's own index.
  // NOTE(review): the expression below looks garbled by extraction.
  override def getPartitions: Array[Partition] = {
    firstParent[T] => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))

  // NOTE(review): the right-hand side of this definition is not present in the listing.
  override def getPreferredLocations(split: Partition): Seq[String] =

  // Pairs each element with `split.startIndex` plus the second component of the tuple `x`.
  // NOTE(review): the call that produces the tuples `x` is not present in this listing.
  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
    firstParent[T].iterator(split.prev, context) { x =>
      (x._1, split.startIndex + x._2)
Example 155
Source File: MapPartitionsWithPreparationRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.{Partition, Partitioner, TaskContext}

  // Chooses a prepared argument depending on whether `preparedArguments` is empty, then
  // calls `executePartition` with it and the parent partition's iterator.
  // NOTE(review): both branch bodies are not present in this listing.
  override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
    val prepared =
      if (preparedArguments.isEmpty) {
      } else {
    val parentIterator = firstParent[T].iterator(partition, context)
    executePartition(context, partition.index, prepared, parentIterator)
Source File: LocalRDDCheckpointData.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.Utils

  /**
   * Derives a storage level that always uses disk from the given level.
   *
   * @param level the requested storage level
   * @return a level with `useDisk = true` and the memory, deserialized and replication
   *         settings of `level`
   * @throws SparkException if `level` uses off-heap storage
   */
  def transformStorageLevel(level: StorageLevel): StorageLevel = {
    // If this RDD is to be cached off-heap, fail fast since we cannot provide any
    // correctness guarantees about subsequent computations after the first one
    if (level.useOffHeap) {
      throw new SparkException("Local checkpointing is not compatible with off-heap caching.")
    }

    StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication)
  }
Example 157
Source File: MapPartitionsRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}

/**
 * RDD that applies `f` to every partition of `prev`.
 *
 * @param prev                  the parent RDD
 * @param f                     receives (TaskContext, partition index, iterator) and returns
 *                              the output iterator for that partition
 * @param preservesPartitioning when true, the parent's partitioner is reused; otherwise None
 */
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false)
  extends RDD[U](prev) {

  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

  /** Partitions are exactly those of the parent. */
  override def getPartitions: Array[Partition] = firstParent[T].partitions

  /** Feeds the parent's iterator for `split` through `f`. */
  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))
}
Source File: UnionRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

/**
 * Partition of a union, referring to one partition of one of the parent RDDs.
 *
 * @param idx                     index of this partition
 * @param rdd                     the parent RDD this partition belongs to (transient)
 * @param parentRddIndex          index of the parent RDD this partition belongs to
 * @param parentRddPartitionIndex index of the partition within the parent RDD (transient)
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient rdd: RDD[T],
    val parentRddIndex: Int,
    @transient parentRddPartitionIndex: Int)
  extends Partition {

  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  /** Preferred locations of the referenced parent partition, as reported by the parent RDD. */
  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  @throws(classOf[java.io.IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)
    // A custom writeObject must write the fields itself; without this call nothing
    // (including the refreshed parentPartition) would be serialized.
    oos.defaultWriteObject()
  }
}

/**
 * RDD whose partitions are the partitions of all RDDs in `rdds`, in order.
 *
 * @param sc   the SparkContext passed to the base RDD
 * @param rdds the RDDs being combined; set to null by `clearDependencies`
 */
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  /** Creates one UnionPartition per parent partition, numbered consecutively across parents. */
  override def getPartitions: Array[Partition] = {
    // One slot for every partition of every parent RDD.
    val array = new Array[Partition](rdds.map(_.partitions.length).sum)
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1
    }
    array
  }

  /** Creates one RangeDependency per parent, each starting where the previous one ended. */
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length)
      pos += rdd.partitions.length
    }
    deps
  }

  /** Delegates to the iterator of the referenced parent partition. */
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)
  }

  /** Delegates to the preferred locations recorded by the UnionPartition. */
  override def getPreferredLocations(s: Partition): Seq[String] =
    s.asInstanceOf[UnionPartition[T]].preferredLocations()

  override def clearDependencies() {
    super.clearDependencies()
    rdds = null
  }
}
Source File: PartitionwiseSampledRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.random.RandomSampler
import org.apache.spark.util.Utils

/**
 * Partition that wraps a parent partition together with a seed and reuses the parent's index.
 *
 * @param prev the wrapped parent partition
 * @param seed seed carried with this partition
 */
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index
}

/**
 * RDD that runs a clone of `sampler` over each partition of `prev`.
 *
 * NOTE(review): this listing appears truncated (closing braces and several expressions are
 * absent); comments describe only the visible lines.
 *
 * @param prev                  RDD to sample from
 * @param sampler               sampler that is cloned for every computed partition
 * @param preservesPartitioning when true, the parent's partitioner is reused; otherwise None
 * @param seed                  seed for the generator that produces the per-partition seeds;
 *                              defaults to a value drawn from `Utils.random`
 */
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
    prev: RDD[T],
    sampler: RandomSampler[T, U],
    @transient preservesPartitioning: Boolean,
    @transient seed: Long = Utils.random.nextLong)
  extends RDD[U](prev) {

  @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None

  // Each partition receives its own seed drawn from a generator seeded with `seed`.
  // NOTE(review): the expression on the last line below looks garbled by extraction.
  override def getPartitions: Array[Partition] = {
    val random = new Random(seed)
    firstParent[T] => new PartitionwiseSampledRDDPartition(x, random.nextLong()))

  // NOTE(review): the right-hand side of this definition is not present in the listing.
  override def getPreferredLocations(split: Partition): Seq[String] =

  // Samples the parent partition's iterator with a fresh clone of the sampler.
  override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
    val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
    val thisSampler = sampler.clone
    thisSampler.sample(firstParent[T].iterator(split.prev, context))
Source File: CheckpointRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, SparkContext, TaskContext}

/**
 * Abstract base RDD for checkpoint data. It has no parent dependencies (Nil) and turns all
 * checkpointing entry points into no-ops.
 *
 * @param sc the SparkContext passed to the base RDD (transient)
 */
private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkContext)
  extends RDD[T](sc, Nil) {

  // CheckpointRDD should not be checkpointed again
  override def doCheckpoint(): Unit = { }
  override def checkpoint(): Unit = { }
  override def localCheckpoint(): this.type = this

  // Note: There is a bug in MiMa that complains about `AbstractMethodProblem`s in the
  // base [[org.apache.spark.rdd.RDD]] class if we do not override the following methods.
  // scalastyle:off
  protected override def getPartitions: Array[Partition] = ???
  override def compute(p: Partition, tc: TaskContext): Iterator[T] = ???
  // scalastyle:on

}
Example 161
Source File: FakeTask.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.spark.TaskContext

/**
 * Task stub that extends Task, always returns 0 from `runTask`, and reports the preferred
 * locations it was constructed with.
 *
 * @param stageId  stage id forwarded to the Task constructor
 * @param prefLocs preferred locations returned by `preferredLocations`; defaults to Nil
 */
class FakeTask(
    stageId: Int,
    prefLocs: Seq[TaskLocation] = Nil)
  extends Task[Int](stageId, 0, 0, Seq.empty) {
  override def runTask(context: TaskContext): Int = 0
  override def preferredLocations: Seq[TaskLocation] = prefLocs
}

// Companion helpers for building fake tasks.
// NOTE(review): the enclosing method signature(s) are not present in this listing. The visible
// lines build `numTasks` FakeTask instances (task i gets prefLocs(i) when prefLocs is
// non-empty, otherwise Nil) and wrap them in a TaskSet.
object FakeTask { // fake task
    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
      new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
    new TaskSet(tasks, 0, stageAttemptId, 0, null)
Source File: OutputCommitCoordinatorIntegrationSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.{Span, Seconds}

import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext}
import org.apache.spark.util.Utils

/**
 * Integration suite that runs a save job on a local SparkContext with speculation enabled.
 *
 * NOTE(review): this listing appears truncated (closing braces and some statements are
 * absent); comments describe only the visible lines.
 */
class OutputCommitCoordinatorIntegrationSuite
  extends SparkFunSuite
  with LocalSparkContext
  with Timeouts {

  // Creates the SparkContext with master "local[2, 4]" and speculation turned on.
  override def beforeAll(): Unit = {
    val conf = new SparkConf()
      .set("master", "local[2,4]")
      .set("spark.speculation", "true")
    sc = new SparkContext("local[2, 4]", "test", conf)

  test("exception thrown in OutputCommitter.commitTask()") { // exception thrown
    // Regression test for SPARK-10381
    // The save must complete within 60 seconds.
    failAfter(Span(60, Seconds)) {
      val tempDir = Utils.createTempDir()
      try {
        sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out")
      } finally {

/**
 * Output committer that throws on a task's first attempt (attempt number 0) and delegates
 * to the regular FileOutputCommitter on later attempts.
 */
private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter {
  override def commitTask(context: TaskAttemptContext): Unit = {
    val ctx = TaskContext.get()
    if (ctx.attemptNumber < 1) {
      throw new java.io.FileNotFoundException("Intentional exception")
    }
    // Later attempts commit normally so the job can finish.
    super.commitTask(context)
  }
}
Source File: PartitionPruningRDDSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext}

/**
 * Tests for PartitionPruningRDD built on small RDDs made of TestPartition instances.
 *
 * NOTE(review): this listing appears truncated (closing braces and some expressions are
 * absent); comments describe only the visible lines.
 */
class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext {

  test("Pruned Partitions inherit locality prefs correctly") { // pruned partitions keep correct settings

    val rdd = new RDD[Int](sc, Nil) { // Nil: no parent dependencies
      override protected def getPartitions = {
          new TestPartition(0, 1),
          new TestPartition(1, 1),
          new TestPartition(2, 1))

      def compute(split: Partition, context: TaskContext) = {
    // Keep only the partition whose index is 2.
    val prunedRDD = PartitionPruningRDD.create(rdd, _ == 2)
    assert(prunedRDD.partitions.length == 1)
    val p = prunedRDD.partitions(0)
    // The surviving partition is re-indexed to 0 but still points at parent split 2.
    assert(p.index == 0)
    assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)

  test("Pruned Partitions can be unioned ") { // pruned partitions can be unioned
    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 4),
          new TestPartition(1, 5),
          new TestPartition(2, 6))

      def compute(split: Partition, context: TaskContext) = {
    val prunedRDD1 = PartitionPruningRDD.create(rdd, _ == 0)

    val prunedRDD2 = PartitionPruningRDD.create(rdd, _ == 2)

    // The union of the two pruned RDDs is expected to hold the values 4 and 6, in that order.
    val merged = prunedRDD1 ++ prunedRDD2
    assert(merged.count() == 2)
    val take = merged.take(2)
    assert(take.apply(0) == 4)
    assert(take.apply(1) == 6)

/**
 * Partition used by the tests, exposing its index and an arbitrary test value.
 *
 * @param i     value returned by `index`
 * @param value value returned by `testValue`
 */
class TestPartition(i: Int, value: Int) extends Partition with Serializable {
  def index: Int = i
  def testValue: Int = this.value
}
Source File: WithCalcTransactionLogging.scala    From languagedetector   with MIT License 5 votes vote down vote up
package biz.meetmatch.decorators

import biz.meetmatch.logging.BusinessLogger
import org.apache.spark.TaskContext

/** Decorator that brackets a computation with transaction start/stop log entries. */
object WithCalcTransactionLogging {

  /**
   * Logs a transaction start (including the current task's stage id, partition id and task
   * attempt id), evaluates `f`, logs the transaction stop, and returns the value of `f`.
   *
   * @param category transaction category passed to the logger
   * @param id       transaction id passed to the logger
   * @param message  optional message passed with the start entry
   * @param f        the computation to run; evaluated exactly once
   * @param module   class whose name is used to create the BusinessLogger
   * @return the result of `f`
   */
  def apply[B](category: String, id: String, message: String = "")(f: => B)(implicit module: Class[_]): B = {
    val businessLogger = new BusinessLogger(module.getName)

    // NOTE(review): TaskContext.get is dereferenced without a null check, so this presumably
    // must be called from inside a running task - confirm against callers.
    val taskContext = TaskContext.get
    businessLogger.transactionStarted(
      category,
      id,
      taskContext.stageId,
      taskContext.partitionId,
      taskContext.taskAttemptId,
      message)
    val result = f
    businessLogger.transactionStopped(category, id)

    // The computed value, not the logger's return value, is the result of the decorator.
    result
  }
}

Example 165
Source File: SlidingRDD.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.rdd

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD

/**
 * Partition that pairs a parent partition with extra trailing elements and an offset.
 *
 * @param idx    index of this partition
 * @param prev   the wrapped parent partition
 * @param tail   elements appended after the parent partition's own elements
 * @param offset offset carried with this partition
 */
class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int)
  extends Partition with Serializable {
  override val index: Int = idx
}

/**
 * RDD of element arrays obtained by sliding a window of `windowSize` elements, advancing by
 * `step`, over the elements of `parent`.
 *
 * NOTE(review): this listing appears truncated (closing braces and several expressions are
 * absent); comments describe only the visible lines.
 *
 * @param parent     the RDD whose elements are windowed
 * @param windowSize window size; must be greater than 0
 * @param step       step between windows; must be greater than 0 (and not 1 when windowSize is 1)
 */
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int)
  extends RDD[Array[T]](parent) {

  require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1),
    "Window size and step must be greater than 0, " +
      s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.")

  // Slides over the parent partition's elements followed by the tail stored on the partition.
  override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
    val part = split.asInstanceOf[SlidingRDDPartition[T]]
    (firstParent[T].iterator(part.prev, context) ++ part.tail)
      .sliding(windowSize, step)

  // NOTE(review): the right-hand side of this definition is not present in the listing.
  override def getPreferredLocations(split: Partition): Seq[String] =

  override def getPartitions: Array[Partition] = {
    val parentPartitions = parent.partitions
    val n = parentPartitions.length
    if (n == 0) {
    } else if (n == 1) {
      // A single parent partition has no successor: empty tail, zero offset.
      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0))
    } else {
      val w1 = windowSize - 1
      // Get partition sizes and first w1 elements.
      val (sizes, heads) = parent.mapPartitions { iter =>
        val w1Array = iter.take(w1).toArray
        // Size = elements already taken plus whatever remains in the iterator.
        Iterator.single((w1Array.length + iter.length, w1Array))
      val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]]
      var i = 0
      var cumSize = 0
      var partitionIndex = 0
      while (i < n) {
        // offset: distance from this partition's first element to the next position whose
        // cumulative index (cumSize-based) is a multiple of step.
        val mod = cumSize % step
        val offset = if (mod == 0) 0 else step - mod
        val size = sizes(i)
        if (offset < size) {
          val tail = mutable.ListBuffer.empty[T]
          // Keep appending to the current tail until it has w1 elements.
          var j = i + 1
          while (j < n && tail.length < w1) {
            tail ++= heads(j).take(w1 - tail.length)
            j += 1
          // Only emit a partition when at least one full window fits after the offset.
          if (sizes(i) + tail.length >= offset + windowSize) {
            partitions +=
              new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset)
            partitionIndex += 1
        cumSize += size
        i += 1

Example 166
Source File: CommitFailureTestSource.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

/**
 * Test data source whose writers fail on purpose: `close()` always raises an error, and
 * `write` raises one when `SimpleTextRelation.failWriter` is set.
 *
 * NOTE(review): this listing appears truncated (closing braces and some statements are
 * absent); comments describe only the visible lines.
 */
class CommitFailureTestSource extends SimpleTextSource {
  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory =
    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new SimpleTextOutputWriter(path, dataSchema, context) {
          var failed = false
          // Records on task failure that the failure callback was invoked.
          TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
            failed = true
            SimpleTextRelation.callbackCalled = true

          override def write(row: InternalRow): Unit = {
            if (SimpleTextRelation.failWriter) {
              sys.error("Intentional task writer failure for testing purpose.")


          override def close(): Unit = {
            sys.error("Intentional task commitment failure for testing purpose.")

      // Written files get no extension.
      override def getFileExtension(context: TaskAttemptContext): String = ""

  override def shortName(): String = "commit-failure-test"
Source File: ObjectAggregationMap.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.aggregate

import java.{util => ju}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.internal.config
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, TypedImperativeAggregate}
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

  // Creates an UnsafeKVExternalSorter and walks the map's entries via `iterator`,
  // distinguishing TypedImperativeAggregate functions from all others per entry.
  // NOTE(review): the sorter's constructor arguments, the projection, the case bodies and the
  // rest of the loop are not present in this listing.
  def dumpToExternalSorter(
      groupingAttributes: Seq[Attribute],
      aggregateFunctions: Seq[AggregateFunction]): UnsafeKVExternalSorter = {
    // Buffer attributes of all aggregate functions, flattened into one sequence.
    val aggBufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
    val sorter = new UnsafeKVExternalSorter(

    val mapIterator = iterator
    val unsafeAggBufferProjection =

    while (mapIterator.hasNext) {
      val entry =
      aggregateFunctions.foreach {
        case agg: TypedImperativeAggregate[_] =>
        case _ =>



  // NOTE(review): the body of this method is not present in this listing.
  def clear(): Unit = {

// Stores the grouping key and aggregation buffer.
// Both fields are declared `var`, so an entry can be re-pointed at another key or buffer.
class AggregationBufferEntry(var groupingKey: UnsafeRow, var aggregationBuffer: InternalRow) 
Example 168
Source File: ShuffledHashJoinExec.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
 * Hash join in which both children are required to be hash-clustered on their join keys;
 * per partition, a hashed relation is built from the build side and the streamed side is
 * joined against it.
 *
 * @param leftKeys  join keys of the left child
 * @param rightKeys join keys of the right child
 * @param joinType  the join type
 * @param buildSide which side the hashed relation is built from
 * @param condition optional additional join condition
 * @param left      left child plan
 * @param right     right child plan
 */
case class ShuffledHashJoinExec(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    buildSide: BuildSide,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan)
  extends BinaryExecNode with HashJoin {

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"),
    "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))

  override def requiredChildDistribution: Seq[Distribution] =
    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil

  /**
   * Builds the hashed relation for one partition of the build side, recording the build time
   * (nanoseconds divided by 1,000,000) and estimated size in the metrics, and registering a
   * task-completion listener that closes the relation.
   *
   * @param iter rows of the build side for the current partition
   * @return the relation that was built
   */
  private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
    val buildDataSize = longMetric("buildDataSize")
    val buildTime = longMetric("buildTime")
    val start = System.nanoTime()
    val context = TaskContext.get()
    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
    buildTime += (System.nanoTime() - start) / 1000000
    buildDataSize += relation.estimatedSize
    // This relation is usually used until the end of task.
    context.addTaskCompletionListener(_ => relation.close())
    // Return the relation itself; the listener registration above is only a side effect.
    relation
  }

  /** Zips streamed and build partitions and joins each pair through the hashed relation. */
  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    val avgHashProbe = longMetric("avgHashProbe")
    streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
      val hashed = buildHashedRelation(buildIter)
      join(streamIter, hashed, numOutputRows, avgHashProbe)
    }
  }
}
Source File: CodecStreams.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.io.{InputStream, OutputStream, OutputStreamWriter}
import java.nio.charset.{Charset, StandardCharsets}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.util.ReflectionUtils

import org.apache.spark.TaskContext

// Helpers for opening streams with optional (de)compression.
// NOTE(review): this listing appears truncated (closing braces and several expressions are
// absent); comments describe only the visible lines.
object CodecStreams {
  // Looks up a decompression codec for `file` using a CompressionCodecFactory built from `config`.
  private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = {
    val compressionCodecs = new CompressionCodecFactory(config)

  // Opens `file` through its file system and, when a codec is found for it, wraps the
  // stream with that codec's input stream.
  def createInputStream(config: Configuration, file: Path): InputStream = {
    val fs = file.getFileSystem(config)
    val inputStream: InputStream =

    getDecompressionCodec(config, file)
      .map(codec => codec.createInputStream(inputStream))

  // NOTE(review): the body of this method is not present in this listing.
  def getCompressionExtension(context: JobContext): String = {
Example 170
Source File: DataSourceRDD.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.v2

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory

// Serializable partition that carries its index together with the reader factory used to
// read its data.
class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T])
  extends Partition with Serializable

/**
 * RDD with no parent dependencies (Nil) whose partitions each wrap one DataReaderFactory.
 *
 * NOTE(review): this listing appears truncated (closing braces and several expressions are
 * absent); comments describe only the visible lines.
 *
 * @param sc              the SparkContext passed to the base RDD
 * @param readerFactories one reader factory per partition (transient)
 */
class DataSourceRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val readerFactories: java.util.List[DataReaderFactory[T]])
  extends RDD[T](sc, Nil) {

  // Pairs each reader factory with its position to create the partitions.
  override protected def getPartitions: Array[Partition] = { {
      case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)

  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader()
    // Close the reader when the task completes.
    context.addTaskCompletionListener(_ => reader.close())
    val iter = new Iterator[T] {
      // True once hasNext has prepared a value that next() has not yet consumed.
      private[this] var valuePrepared = false

      override def hasNext: Boolean = {
        if (!valuePrepared) {
          valuePrepared =

      override def next(): T = {
        if (!hasNext) {
          throw new java.util.NoSuchElementException("End of stream")
        valuePrepared = false
    // Wrap the iterator in an InterruptibleIterator bound to the task context.
    new InterruptibleIterator(context, iter)

  // NOTE(review): the body of this method is not present in this listing.
  override def getPreferredLocations(split: Partition): Seq[String] = {
Source File: FlatMapGroupsInPandasExec.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.types.StructType

/**
 * Physical plan node that groups its child's rows by `groupingAttributes` and feeds the
 * groups through an ArrowPythonRunner running the Python function extracted from `func`.
 *
 * NOTE(review): this listing appears truncated (closing braces and several expressions are
 * absent); comments describe only the visible lines.
 *
 * @param groupingAttributes attributes the input is grouped by
 * @param func               expression cast to PythonUDF to obtain the Python function
 * @param output             output attributes of this node
 * @param child              child plan providing the input rows
 */
case class FlatMapGroupsInPandasExec(
    groupingAttributes: Seq[Attribute],
    func: Expression,
    output: Seq[Attribute],
    child: SparkPlan)
  extends UnaryExecNode {

  private val pandasFunction = func.asInstanceOf[PythonUDF].func

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override def producedAttributes: AttributeSet = AttributeSet(output)

  // Without grouping attributes all tuples are required together; otherwise the input must be
  // clustered by the grouping attributes.
  override def requiredChildDistribution: Seq[Distribution] = {
    if (groupingAttributes.isEmpty) {
      AllTuples :: Nil
    } else {
      ClusteredDistribution(groupingAttributes) :: Nil

  // NOTE(review): the expression below looks garbled by extraction.
  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
    Seq(, Ascending)))

  override protected def doExecute(): RDD[InternalRow] = {
    val inputRDD = child.execute()

    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
    val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
    // Argument offsets cover the child's output columns minus the grouping columns.
    val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
    // Schema sent to Python: the child's schema with the leading grouping columns dropped.
    val schema = StructType(child.schema.drop(groupingAttributes.length))
    val sessionLocalTimeZone = conf.sessionLocalTimeZone
    val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone

    inputRDD.mapPartitionsInternal { iter =>
      val grouped = if (groupingAttributes.isEmpty) {
      } else {
        val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
        val dropGrouping =
          UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) {
          case (_, groupedRowIter) =>

      val context = TaskContext.get()

      val columnarBatchIter = new ArrowPythonRunner(
        chainedFunc, bufferSize, reuseWorker,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema,
        sessionLocalTimeZone, pandasRespectSessionTimeZone)
          .compute(grouped, context.partitionId(), context)

      // Flatten the returned batches into rows and project them with an UnsafeProjection.
      columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))
Source File: ArrowEvalPythonExec.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.StructType

// Physical plan node that evaluates `udfs` through an ArrowPythonRunner using the
// SQL_SCALAR_PANDAS_UDF eval type.
// NOTE(review): several right-hand sides in this class are missing from the text as
// extracted (marked below); confirm against the upstream file.
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends EvalPythonExec(udfs, output, child) {

  // Read once from the session conf at construction time.
  private val batchSize = conf.arrowMaxRecordsPerBatch
  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
  private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone

  // Sends `iter` to Python in batches and returns an iterator over the rows of the
  // columnar batches that come back.
  protected override def evaluate(
      funcs: Seq[ChainedPythonFunctions],
      bufferSize: Int,
      reuseWorker: Boolean,
      argOffsets: Array[Array[Int]],
      iter: Iterator[InternalRow],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow] = {

    // Data types of the attributes that follow the child's own output attributes.
    val outputTypes = output.drop(child.output.length).map(_.dataType)

    // DO NOT use iter.grouped(). See BatchIterator.
    // A non-positive batchSize sends the whole input as one batch.
    val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)

    val columnarBatchIter = new ArrowPythonRunner(
        funcs, bufferSize, reuseWorker,
        PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema,
        sessionLocalTimeZone, pandasRespectSessionTimeZone)
      .compute(batchIter, context.partitionId(), context)

    new Iterator[InternalRow] {

      // Row iterator of the batch currently being consumed. The first batch is also used
      // to assert that the returned column types equal `outputTypes`.
      // NOTE(review): the value assigned to `batch`, the value of the `if` branch and the
      // whole `else` branch are missing here — TODO confirm.
      private var currentIter = if (columnarBatchIter.hasNext) {
        val batch =
        val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
        assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
          s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
      } else {

      // True while the current batch has rows; otherwise tries to move on to the next batch.
      // NOTE(review): the new value of `currentIter` and the `else` result are missing here.
      override def hasNext: Boolean = currentIter.hasNext || {
        if (columnarBatchIter.hasNext) {
          currentIter =
        } else {

      // NOTE(review): the body of `next()` is missing from the text as extracted.
      override def next(): InternalRow =
Example 173
Source File: BatchEvalPythonExec.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}

// Physical plan node that evaluates `udfs` by pickling input rows in groups of 100,
// running them through a PythonUDFRunner (SQL_BATCHED_UDF), and unpickling the results.
// NOTE(review): several expressions in this class are missing from the text as extracted
// (marked below); confirm against the upstream file.
case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends EvalPythonExec(udfs, output, child) {

  protected override def evaluate(
      funcs: Seq[ChainedPythonFunctions],
      bufferSize: Int,
      reuseWorker: Boolean,
      argOffsets: Array[Array[Int]],
      iter: Iterator[InternalRow],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow] = {
    EvaluatePython.registerPicklers()  // register pickler for Row

    // NOTE(review): the right-hand side of `dataTypes` is missing — presumably the data
    // types of `schema`'s fields; TODO confirm.
    val dataTypes =
    val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)

    // enable memo iff we serialize the row with schema (schema and class should be memorized)
    val pickle = new Pickler(needConversion)
    // Input iterator to Python: input rows are grouped so we send them in batches to Python.
    // For each row, add it to the queue.
    // NOTE(review): the receiver of this lambda (presumably `iter.map`) is missing, as are
    // the binding of `dt` used below and the value returned by the fast path.
    val inputIterator = { row =>
      if (needConversion) {
        EvaluatePython.toJava(row, schema)
      } else {
        // fast path for these types that does not need conversion in Python
        val fields = new Array[Any](row.numFields)
        var i = 0
        while (i < row.numFields) {
          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
          i += 1
    }.grouped(100).map(x => pickle.dumps(x.toArray))

    // Output iterator for results from Python.
    val outputIterator = new PythonUDFRunner(
        funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
      .compute(inputIterator, context.partitionId(), context)

    val unpickle = new Unpickler
    // Single-column row that the single-UDF fast path below writes into.
    val mutableRow = new GenericInternalRow(1)
    // NOTE(review): the single-UDF branch is empty and the collection being mapped in the
    // multi-UDF branch is missing — TODO confirm.
    val resultType = if (udfs.length == 1) {
    } else {
      StructType( => StructField("", u.dataType, u.nullable)))

    val fromJava = EvaluatePython.makeFromJava(resultType)

    // NOTE(review): the value produced from `unpickledBatch`, the result of the single-UDF
    // branch and the whole multi-UDF branch are missing below.
    outputIterator.flatMap { pickedResult =>
      val unpickledBatch = unpickle.loads(pickedResult)
    }.map { result =>
      if (udfs.length == 1) {
        // fast path for single UDF
        mutableRow(0) = fromJava(result)
      } else {
Example 174
Source File: StateStoreRDD.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.state

import java.util.UUID

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

  // Builds the state store provider id for this partition from the checkpoint location,
  // operator id and partition index.
  // NOTE(review): the rest of the StateStoreProviderId call and whatever the method
  // returns are missing from the text as extracted — TODO confirm.
  override def getPreferredLocations(partition: Partition): Seq[String] = {
    val stateStoreProviderId = StateStoreProviderId(
      StateStoreId(checkpointLocation, operatorId, partition.index),

  // Obtains the state store for this partition and applies `storeUpdateFunction` to it
  // together with the parent RDD's iterator for the same partition.
  override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
    var store: StateStore = null
    // NOTE(review): the trailing argument(s) and closing parenthesis of this
    // StateStoreProviderId call are missing from the text as extracted.
    val storeProviderId = StateStoreProviderId(
      StateStoreId(checkpointLocation, operatorId, partition.index),

    store = StateStore.get(
      storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion,
      storeConf, hadoopConfBroadcast.value.value)
    val inputIter = dataRDD.iterator(partition, ctxt)
    storeUpdateFunction(store, inputIter)
Example 175
Source File: package.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import scala.reflect.ClassTag

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType

package object state {

  // Adds `mapPartitionsWithStateStore` to any RDD[T].
  implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {

    // Wraps `storeUpdateFunction` so that the state store is aborted at task completion
    // whenever it has not been committed.
    // NOTE(review): the construction of the returned StateStoreRDD is missing from the
    // text as extracted, as are the closers of the listener and of `wrappedF`.
    private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
        stateInfo: StatefulOperatorStateInfo,
        keySchema: StructType,
        valueSchema: StructType,
        indexOrdinal: Option[Int],
        sessionState: SessionState,
        storeCoordinator: Option[StateStoreCoordinatorRef])(
        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {

      // Pass the user function through SparkContext.clean before capturing it.
      val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
      val wrappedF = (store: StateStore, iter: Iterator[T]) => {
        // Abort the state store in case of error
        TaskContext.get().addTaskCompletionListener(_ => {
          if (!store.hasCommitted) store.abort()
        cleanedF(store, iter)

Example 176
Source File: ReferenceSort.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

/**
 * Sort operator built on [[ExternalSorter]].
 *
 * @param sortOrder ordering applied to the rows
 * @param global    when true the child must be distributed by `OrderedDistribution(sortOrder)`;
 *                  otherwise no distribution is required
 * @param child     plan whose output is sorted
 */
case class ReferenceSort(
    sortOrder: Seq[SortOrder],
    global: Boolean,
    child: SparkPlan)
  extends UnaryExecNode {

  override def requiredChildDistribution: Seq[Distribution] =
    if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

  /** Sorts each partition of the child independently, keeping the child's partitioning. */
  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
    child.execute().mapPartitions( { iterator =>
      val ordering = newOrdering(sortOrder, child.output)
      val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
        TaskContext.get(), ordering = Some(ordering))
      // Each row is copied before insertion and used as the key; the value is always null.
      sorter.insertAll(iterator.map(r => (r.copy(), null)))
      // The sorter yields (key, combiner) pairs; the row is the key.
      val baseIterator = sorter.iterator.map(_._1)
      // NOTE(review): `context` is unused in the visible text; statements that used it
      // (presumably task-metric updates) appear to have been lost — TODO confirm upstream.
      val context = TaskContext.get()
      // Stop the sorter once the sorted rows have been fully consumed.
      CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
    }, preservesPartitioning = true)
  }

  override def output: Seq[Attribute] = child.output

  override def outputOrdering: Seq[SortOrder] = sortOrder

  override def outputPartitioning: Partitioning = child.outputPartitioning
}
Example 177
Source File: SparkHadoopMapRedUtil.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mapred


import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext}
import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.internal.Logging

object SparkHadoopMapRedUtil extends Logging {
  /**
   * Commits a task's output if the committer reports that a commit is needed.
   *
   * When output-commit coordination is enabled, the commit only proceeds if the driver's
   * output commit coordinator authorises this attempt. A denied attempt aborts its output
   * and fails with [[CommitDeniedException]].
   *
   * @param committer     Hadoop output committer for the task
   * @param mrTaskContext Hadoop task attempt context passed to the committer
   * @param jobId         job identifier (not used by this method)
   * @param splitId       partition id used when asking the driver for permission to commit
   * @throws IOException           rethrown when the committer fails to commit
   * @throws CommitDeniedException when the driver does not authorise the commit
   */
  def commitTask(
      committer: MapReduceOutputCommitter,
      mrTaskContext: MapReduceTaskAttemptContext,
      jobId: Int,
      splitId: Int): Unit = {

    val mrTaskAttemptID = mrTaskContext.getTaskAttemptID

    // Called after we have decided to commit
    def performCommit(): Unit = {
      try {
        committer.commitTask(mrTaskContext)
        logInfo(s"$mrTaskAttemptID: Committed")
      } catch {
        case cause: IOException =>
          logError(s"Error committing the output of task: $mrTaskAttemptID", cause)
          // Discard the partial output of the failed commit before propagating the error.
          committer.abortTask(mrTaskContext)
          throw cause
      }
    }

    // First, check whether the task's output has already been committed by some other attempt
    if (committer.needsTaskCommit(mrTaskContext)) {
      val shouldCoordinateWithDriver: Boolean = {
        val sparkConf = SparkEnv.get.conf
        // We only need to coordinate with the driver if there are concurrent task attempts.
        // Note that this could happen even when speculation is not enabled (e.g. see SPARK-8029).
        // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs.
        sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", defaultValue = true)
      }

      if (shouldCoordinateWithDriver) {
        val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
        val taskAttemptNumber = TaskContext.get().attemptNumber()
        val stageId = TaskContext.get().stageId()
        val canCommit = outputCommitCoordinator.canCommit(stageId, splitId, taskAttemptNumber)

        if (canCommit) {
          performCommit()
        } else {
          val message =
            s"$mrTaskAttemptID: Not committed because the driver did not authorize commit"
          logInfo(message)
          // We need to abort the task so that the driver can reschedule new attempts, if necessary
          committer.abortTask(mrTaskContext)
          throw new CommitDeniedException(message, stageId, splitId, taskAttemptNumber)
        }
      } else {
        // Speculation is disabled or a user has chosen to manually bypass the commit coordination
        performCommit()
      }
    } else {
      // Some other attempt committed the output, so we do nothing and signal success
      logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID")
    }
  }
}
Example 178
Source File: taskListeners.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import java.util.EventListener

import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi

/**
 * Exception carrying the error messages raised by task completion listeners.
 *
 * @param errorMessages messages collected from the failing listeners
 * @param previousError an earlier failure of the task, if any; its message and stack trace
 *                      are appended to this exception's message
 */
class TaskCompletionListenerException(
    errorMessages: Seq[String],
    val previousError: Option[Throwable] = None)
  extends RuntimeException {

  /**
   * A single listener message is returned as is; several are numbered ("Exception i: ...")
   * and joined with newlines. Details of `previousError`, when present, follow.
   */
  override def getMessage: String = {
    val listenerErrorMessage =
      if (errorMessages.size == 1) {
        errorMessages.head
      } else {
        errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
      }
    // Fall back to the empty string so the concatenation below never renders an Option.
    val previousErrorMessage = previousError.map { e =>
      "\n\nPrevious exception in task: " + e.getMessage + "\n" +
        e.getStackTrace.mkString("\t", "\n\t", "")
    }.getOrElse("")
    listenerErrorMessage + previousErrorMessage
  }
}
Example 179
Source File: ZippedWithIndexRDD.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils

/**
 * Partition that wraps a parent partition together with the index assigned to that
 * partition's first element.
 *
 * @param prev       the wrapped parent partition
 * @param startIndex index of the first element of this partition
 */
class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index
}

  // Start index of each partition, built as a running sum (scanLeft from 0L) over
  // per-partition sizes for every partition except the last.
  // NOTE(review): the bodies of the `n == 0` and `n == 1` branches and the head of the
  // call that receives `Utils.getIteratorSize _` (presumably a job launched on `prev`)
  // are missing from the text as extracted — TODO confirm.
  @transient private val startIndices: Array[Long] = {
    val n = prev.partitions.length
    if (n == 0) {
    } else if (n == 1) {
    } else {
        Utils.getIteratorSize _,
        0 until n - 1 // do not need to count the last partition
      ).scanLeft(0L)(_ + _)

  /**
   * One [[ZippedWithIndexRDDPartition]] per parent partition, each carrying the start
   * index recorded for it in `startIndices`.
   */
  override def getPartitions: Array[Partition] = {
    firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
  }

  // NOTE(review): the body of this method is missing from the text as extracted;
  // presumably it delegates to the parent RDD for the wrapped partition — TODO confirm.
  override def getPreferredLocations(split: Partition): Seq[String] =

  /**
   * Pairs every element of the parent partition with a consecutive index, starting at
   * the partition's `startIndex`.
   */
  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
    val parentIter = firstParent[T].iterator(split.prev, context)
    Utils.getIteratorZipWithIndex(parentIter, split.startIndex)
  }
Example 180
Source File: UnionRDD.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.forkjoin.ForkJoinPool
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

/**
 * Partition of a union RDD, referring to one partition of one parent RDD.
 *
 * @param idx                     index of this partition within the union
 * @param rdd                     the parent RDD (transient, not serialized)
 * @param parentRddIndex          index of the parent RDD within the union
 * @param parentRddPartitionIndex index of the referenced partition within the parent RDD
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient private val rdd: RDD[T],
    val parentRddIndex: Int,
    @transient private val parentRddPartitionIndex: Int)
  extends Partition {

  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)
    // A custom writeObject must still write the fields itself; without this call none of
    // the non-transient state would reach the stream.
    oos.defaultWriteObject()
  }
}

object UnionRDD {
  // Task support backed by a fork-join pool of 8 threads; UnionRDD assigns it to the
  // parallel collection it uses when partitions are listed in parallel.
  private[spark] lazy val partitionEvalTaskSupport =
    new ForkJoinTaskSupport(new ForkJoinPool(8))
}

// RDD formed from the partitions of all RDDs in `rdds`; parents are linked through
// range dependencies.
// NOTE(review): several expressions and all closing braces of this class are missing
// from the text as extracted (marked below); confirm against the upstream file.
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  // visible for testing
  // True when there are more parent RDDs than "spark.rdd.parallelListingThreshold"
  // (default 10).
  private[spark] val isPartitionListingParallel: Boolean =
    rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)

  // Creates one UnionPartition per parent partition, numbered consecutively across parents.
  // NOTE(review): the results of both `parRDDs` branches, the size passed to
  // `new Array[Partition](` and the returned value are missing here — TODO confirm.
  override def getPartitions: Array[Partition] = {
    val parRDDs = if (isPartitionListingParallel) {
      val parArray = rdds.par
      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
    } else {
    val array = new Array[Partition](
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1

  // One RangeDependency per parent; `pos` is the running total of parent partition counts
  // and is used as each range's output start.
  // NOTE(review): the returned value (presumably `deps`) is missing here.
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length)
      pos += rdd.partitions.length

  // Delegates to the iterator of the referenced parent partition.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)

  // NOTE(review): the body of this method is missing from the text as extracted.
  override def getPreferredLocations(s: Partition): Seq[String] =

  // Drops the reference to the parent RDDs.
  // NOTE(review): a call to `super.clearDependencies()` may have been lost — TODO confirm.
  override def clearDependencies() {
    rdds = null
Example 181
Source File: PartitionwiseSampledRDD.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils
import org.apache.spark.util.random.RandomSampler

/**
 * Partition that wraps a parent partition together with a seed.
 *
 * @param prev the wrapped parent partition
 * @param seed seed associated with this partition
 */
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index
}

// RDD that samples each partition of `prev` with a clone of `sampler`.
// NOTE(review): parts of this class are missing from the text as extracted (marked
// below); confirm against the upstream file.
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
    prev: RDD[T],
    sampler: RandomSampler[T, U],
    preservesPartitioning: Boolean,
    @transient private val seed: Long = Utils.random.nextLong)
  extends RDD[U](prev) {

  // Keeps the parent's partitioner only when the caller asked to preserve partitioning.
  @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None

  // Draws one long per partition from a Random initialised with `seed`.
  // NOTE(review): `firstParent[T] => ...` is garbled — presumably
  // `firstParent[T].partitions.map(x => ...)`; TODO confirm.
  override def getPartitions: Array[Partition] = {
    val random = new Random(seed)
    firstParent[T] => new PartitionwiseSampledRDDPartition(x, random.nextLong()))

  // NOTE(review): the body of this method is missing from the text as extracted.
  override def getPreferredLocations(split: Partition): Seq[String] =

  // Samples the parent partition's iterator with a fresh clone of `sampler`.
  // NOTE(review): `split.seed` is not applied to the clone in the visible text; a seeding
  // call may have been lost — TODO confirm.
  override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
    val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
    val thisSampler = sampler.clone
    thisSampler.sample(firstParent[T].iterator(split.prev, context))
Example 182
Source File: PartitionerAwareUnionRDD.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
import org.apache.spark.util.Utils

// Union of RDDs that must all share one partitioner (enforced by the `require` below);
// the result has that partitioner's number of partitions.
// NOTE(review): this class is heavily garbled in the text as extracted — several
// expressions and every closing brace are missing (marked below). The
// PartitionerAwareUnionRDDPartition type it refers to is not visible either.
class PartitionerAwareUnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]]
  ) extends RDD[T](sc, => new OneToOneDependency(x))) {
  require(rdds.flatMap(_.partitioner).toSet.size == 1,
    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))

  override val partitioner = rdds.head.partitioner

  // One partition per partition of the shared partitioner.
  // NOTE(review): the conversion of the mapped range to an Array is missing here.
  override def getPartitions: Array[Partition] = {
    val numPartitions = partitioner.get.numPartitions
    (0 until numPartitions).map { index =>
      new PartitionerAwareUnionRDDPartition(rdds, index)

  // Get the location where most of the partitions of parent RDDs are located
  // NOTE(review): the collection traversed to build `locations`, the value of the empty
  // branch of `location` and the method's result are missing here.
  override def getPreferredLocations(s: Partition): Seq[String] = {
    logDebug("Finding preferred location for " + this + ", partition " + s.index)
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
    val locations = {
      case (rdd, part) =>
        val parentLocations = currPrefLocs(rdd, part)
        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
    val location = if (locations.isEmpty) {
    } else {
      // Find the location that maximum number of parent partitions prefer
      Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
    logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)

  // Iterates each parent RDD's partition with the task context.
  // NOTE(review): the traversal that pairs `rdds` with `parentPartitions` is missing here.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents {
      case (rdd, p) => rdd.iterator(p, context)

  // Drops the reference to the parent RDDs.
  override def clearDependencies() {
    rdds = null

  // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
  // NOTE(review): the function applied to each `tl` is missing from the text as extracted.
  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
    rdd.context.getPreferredLocs(rdd, part.index).map(tl =>
Example 183
Source File: MemoryTestingUtils.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.memory

import java.util.Properties

import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}

object MemoryTestingUtils {
  /**
   * Builds a [[TaskContext]] with all ids and attempt numbers set to 0.
   *
   * @param env environment supplying the memory manager and the metrics system
   * @return a `TaskContextImpl` backed by a new `TaskMemoryManager` for task attempt 0
   */
  def fakeTaskContext(env: SparkEnv): TaskContext = {
    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
    new TaskContextImpl(
      stageId = 0,
      stageAttemptNumber = 0,
      partitionId = 0,
      taskAttemptId = 0,
      attemptNumber = 0,
      taskMemoryManager = taskMemoryManager,
      localProperties = new Properties,
      metricsSystem = env.metricsSystem)
  }
}
Example 184
Source File: FakeTask.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.util.Properties

import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.executor.TaskMetrics

// Task whose `runTask` returns 0 and whose preferred locations are the ones passed in.
// NOTE(review): the default value of `serializedTaskMetrics` and the class's closing
// brace are missing from the text as extracted — TODO confirm.
class FakeTask(
    stageId: Int,
    partitionId: Int,
    prefLocs: Seq[TaskLocation] = Nil,
    serializedTaskMetrics: Array[Byte] =
  extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics) {

  override def runTask(context: TaskContext): Int = 0
  override def preferredLocations: Seq[TaskLocation] = prefLocs

// Factory helpers that build TaskSets of fake tasks.
// NOTE(review): closing braces throughout, and the last constructor argument in
// createShuffleMapTaskSet, are missing from the text as extracted.
object FakeTask {
  // Same as the three-id overload, with stage attempt id 0 (and therefore stage id 0).
  def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    createTaskSet(numTasks, stageAttemptId = 0, prefLocs: _*)

  // Same as the three-id overload, with stage id 0.
  def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
    createTaskSet(numTasks, stageId = 0, stageAttemptId, prefLocs: _*)

  // Builds `numTasks` FakeTasks. `prefLocs` must be empty or contain exactly one entry per
  // task; any other size raises IllegalArgumentException.
  def createTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*):
  TaskSet = {
    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
      throw new IllegalArgumentException("Wrong number of task locations")
    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
      new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil)
    new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)

  // Builds `numTasks` ShuffleMapTasks, each over an anonymous Partition whose index is the
  // task number. The size check matches createTaskSet.
  // NOTE(review): unlike createTaskSet, `prefLocs(i)` is read unconditionally, so an empty
  // `prefLocs` passes the check and then fails on the lookup — TODO confirm intent.
  def createShuffleMapTaskSet(
      numTasks: Int,
      stageId: Int,
      stageAttemptId: Int,
      prefLocs: Seq[TaskLocation]*): TaskSet = {
    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
      throw new IllegalArgumentException("Wrong number of task locations")
    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
      new ShuffleMapTask(stageId, stageAttemptId, null, new Partition {
        override def index: Int = i
      }, prefLocs(i), new Properties,
    new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
Example 185
Source File: OutputCommitCoordinatorIntegrationSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext}
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.{Seconds, Span}

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext}
import org.apache.spark.util.Utils

// Integration test that saves a small RDD as a text file with output-commit coordination
// enabled, under a 60-second time limit.
// NOTE(review): closing braces, the body of the `finally` block and possibly further
// setup statements are missing from the text as extracted — TODO confirm.
class OutputCommitCoordinatorIntegrationSuite
  extends SparkFunSuite
  with LocalSparkContext
  with TimeLimits {

  // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
  implicit val defaultSignaler: Signaler = ThreadSignaler

  // Creates the SparkContext with master "local[2, 4]" and commit coordination enabled.
  override def beforeAll(): Unit = {
    val conf = new SparkConf()
      .set("spark.hadoop.outputCommitCoordination.enabled", "true")
    sc = new SparkContext("local[2, 4]", "test", conf)

  test("exception thrown in OutputCommitter.commitTask()") {
    // Regression test for SPARK-10381
    failAfter(Span(60, Seconds)) {
      val tempDir = Utils.createTempDir()
      try {
        sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out")
      } finally {

/**
 * Output committer that fails the first attempt of every task (attempt number 0) with an
 * intentional exception and commits normally on later attempts.
 */
private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter {
  override def commitTask(context: TaskAttemptContext): Unit = {
    val ctx = TaskContext.get()
    if (ctx.attemptNumber < 1) {
      throw new java.io.FileNotFoundException("Intentional exception")
    }
    // Retried attempts must actually commit, otherwise no output would ever be written.
    super.commitTask(context)
  }
}
Example 186
Source File: PartitionPruningRDDSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext}

// Tests for PartitionPruningRDD over a hand-built RDD with three TestPartitions.
// NOTE(review): in both tests the head of the partition collection (presumably
// `Array[Partition](`), the body of `compute`, and all closing braces are missing from
// the text as extracted — TODO confirm.
class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext {

  test("Pruned Partitions inherit locality prefs correctly") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 1),
          new TestPartition(1, 1),
          new TestPartition(2, 1))

      def compute(split: Partition, context: TaskContext) = {
    // Keep only the parent partition with index 2: the pruned RDD has a single partition,
    // re-indexed to 0, that still points at parent split 2.
    val prunedRDD = PartitionPruningRDD.create(rdd, _ == 2)
    assert(prunedRDD.partitions.length == 1)
    val p = prunedRDD.partitions(0)
    assert(p.index == 0)
    assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)

  test("Pruned Partitions can be unioned ") {

    val rdd = new RDD[Int](sc, Nil) {
      override protected def getPartitions = {
          new TestPartition(0, 4),
          new TestPartition(1, 5),
          new TestPartition(2, 6))

      def compute(split: Partition, context: TaskContext) = {
    val prunedRDD1 = PartitionPruningRDD.create(rdd, _ == 0)

    val prunedRDD2 = PartitionPruningRDD.create(rdd, _ == 2)

    // Union of the two pruned RDDs: expects two elements, 4 and 6 (the test values of
    // partitions 0 and 2).
    val merged = prunedRDD1 ++ prunedRDD2
    assert(merged.count() == 2)
    val take = merged.take(2)
    assert(take.apply(0) == 4)
    assert(take.apply(1) == 6)

/**
 * Test partition carrying a fixed value.
 *
 * @param i     index reported by the partition
 * @param value value exposed through `testValue`
 */
class TestPartition(i: Int, value: Int) extends Partition with Serializable {
  def index: Int = i
  def testValue: Int = this.value
}
Example 187
Source File: StarryRDD.scala    From starry   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{Partition, SparkContext, TaskContext}

import scala.reflect.ClassTag

// RDD over an in-memory sequence `data` that can be replaced through `updateData`.
// NOTE(review): this class is heavily garbled in the text as extracted — the body of
// `compute`, the assignment target in `updateData`, the header of the method that returns
// the partition array (presumably `getPartitions`) and all closing braces are missing.
class StarryRDD[T: ClassTag](sc: SparkContext,
                             rddName: String,
                             @transient private var data: Seq[T]
                            ) extends RDD[T](sc, Nil) {
  // Auxiliary constructor that passes `getClass.getSimpleName` as the RDD name.
  def this (sc: SparkContext, data: Seq[T]) = {
    this (sc, getClass.getSimpleName, data)


  override def compute(split: Partition, context: TaskContext): Iterator[T] = {

  def updateData(data: Seq[T]): Unit = { = data

    // A single ParallelCollectionPartition (slice 0) holding all of `data`.
    Array(new ParallelCollectionPartition(id, 0, data))
Example 188
Source File: HashSetManager.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up

import edu.ucla.cs.wis.bigdatalog.spark.SchemaInfo
import org.apache.spark.TaskContext
import org.apache.spark.sql.types.{IntegerType, LongType}

// Chooses and creates a HashSet implementation based on a relation's schema.
// NOTE(review): closing braces and the right-hand side of `bytesPerKey` are missing from
// the text as extracted — TODO confirm.
object HashSetManager {
  // Returns a key-type code: 1 for a single IntegerType column, 2 for a single LongType
  // column or for two columns whose `bytesPerKey` is 8, and 3 for everything else.
  def determineKeyType(schemaInfo: SchemaInfo): Int = {
    schemaInfo.arity match {
      case 1 => {
        schemaInfo.schema(0).dataType match {
          case IntegerType => 1
          case LongType => 2
          case other => 3
      case 2 => {
        // NOTE(review): expression missing — presumably the combined byte size of the two
        // key columns; TODO confirm.
        val bytesPerKey =
        if (bytesPerKey == 8) 2 else 3
      case other => 3

  // Maps the key-type code to an implementation: 1 -> IntKeysHashSet,
  // 2 -> LongKeysHashSet, anything else -> ObjectHashSet.
  def create(schemaInfo: SchemaInfo): HashSet = {
    determineKeyType(schemaInfo) match {
      case 1 => new IntKeysHashSet()
      case 2 => new LongKeysHashSet(schemaInfo)
      case _ => new ObjectHashSet()
Example 189
Source File: SlidingRDD.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.rdd

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.{TaskContext, Partition}
import org.apache.spark.rdd.RDD

/**
 * Partition of a sliding-window RDD.
 *
 * @param idx    index of this partition
 * @param prev   the parent partition this one is built on
 * @param tail   elements appended after the parent partition's own elements
 * @param offset offset recorded for this partition
 */
class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int)
  extends Partition with Serializable {
  override val index: Int = idx
}

// RDD of sliding windows (size `windowSize`, stride `step`) over the elements of `parent`.
// NOTE(review): several statements and all closing braces of this class are missing from
// the text as extracted (marked below); confirm against the upstream file.
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int)
  extends RDD[Array[T]](parent) {

  require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1),
    "Window size and step must be greater than 0, " +
      s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.")

  // Slides over the parent partition's elements followed by the partition's `tail`.
  // NOTE(review): steps after `.sliding(...)` may have been lost; `part.offset` is not
  // used in the visible text — TODO confirm.
  override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
    val part = split.asInstanceOf[SlidingRDDPartition[T]]
    (firstParent[T].iterator(part.prev, context) ++ part.tail)
      .sliding(windowSize, step)

  // NOTE(review): the body of this method is missing from the text as extracted.
  override def getPreferredLocations(split: Partition): Seq[String] =

  // Builds the partitions: each kept partition carries up to windowSize - 1 elements taken
  // from the heads of the following parent partitions, plus an offset derived from the
  // number of elements seen so far modulo `step`.
  override def getPartitions: Array[Partition] = {
    val parentPartitions = parent.partitions
    val n = parentPartitions.length
    // NOTE(review): the result of the `n == 0` branch is missing here.
    if (n == 0) {
    } else if (n == 1) {
      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0))
    } else {
      val w1 = windowSize - 1
      // Get partition sizes and first w1 elements.
      // NOTE(review): the end of this statement (presumably collecting the per-partition
      // pairs and unzipping them) is missing here.
      val (sizes, heads) = parent.mapPartitions { iter =>
        val w1Array = iter.take(w1).toArray
        Iterator.single((w1Array.length + iter.length, w1Array))
      val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]]
      var i = 0
      var cumSize = 0
      var partitionIndex = 0
      while (i < n) {
        // Elements to skip so that this partition's first window starts on a multiple
        // of `step`.
        val mod = cumSize % step
        val offset = if (mod == 0) 0 else step - mod
        val size = sizes(i)
        if (offset < size) {
          val tail = mutable.ListBuffer.empty[T]
          // Keep appending to the current tail until it has w1 elements.
          var j = i + 1
          while (j < n && tail.length < w1) {
            tail ++= heads(j).take(w1 - tail.length)
            j += 1
          // Keep the partition only if at least one full window fits after the offset.
          if (sizes(i) + tail.length >= offset + windowSize) {
            partitions +=
              new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset)
            partitionIndex += 1
        cumSize += size
        i += 1

  // TODO: Override methods such as aggregate, which only requires one Spark job.
Example 190
Source File: MonotonicallyIncreasingID.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types.{LongType, DataType}

  // Number of ids produced so far; reset to 0 in initInternal and incremented per row.
  @transient private[this] var count: Long = _

  // Partition id shifted left by 33 bits; set in initInternal and added to every id.
  @transient private[this] var partitionMask: Long = _

  /**
   * Resets the per-partition counter and records the current partition id in the bits
   * above the lowest 33.
   */
  override protected def initInternal(): Unit = {
    count = 0L
    partitionMask = TaskContext.getPartitionId().toLong << 33
  }

  // This expression never evaluates to null.
  override def nullable: Boolean = false

  // The produced value is a Long.
  override def dataType: DataType = LongType

  /**
   * Returns the next id: the partition mask plus the number of ids already produced.
   * The counter is advanced as a side effect; `input` is not read.
   */
  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  // Code-generated counterpart of evalInternal: registers two long mutable states (the
  // counter, initialised to 0, and the partition mask, initialised to the partition id
  // shifted left by 33) and emits code whose value is their sum.
  // NOTE(review): the delimiters of the generated-code template and any statements around
  // the `final ...` line below (e.g. advancing the counter) are missing from the text as
  // extracted — TODO confirm.
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
      s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")

    ev.isNull = "false"
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
Example 191
Source File: randomExpressions.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

// Expression that yields Gaussian-distributed doubles drawn from `rng`.
// NOTE(review): closing braces/parentheses and the delimiters of the generated-code
// template in genCode are missing from the text as extracted. `rng` is not defined in the
// visible text; presumably it comes from RDG — TODO confirm.
case class Randn(seed: Long) extends RDG {
  // `input` is not read; each call draws the next Gaussian value.
  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  // Uses a random seed.
  def this() = this(Utils.random.nextLong())

  // Accepts the seed as an expression; anything other than an integer literal is rejected
  // with an AnalysisException.
  def this(seed: Expression) = this(seed match {
    case IntegerLiteral(s) => s
    case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")

  // Code-generated counterpart of evalInternal: registers an XORShiftRandom mutable state
  // seeded with `seed` plus the partition id, and emits a nextGaussian() call on it.
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val rngTerm = ctx.freshName("rng")
    val className = classOf[XORShiftRandom].getName
    ctx.addMutableState(className, rngTerm,
      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
    ev.isNull = "false"
      final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();
Example 192
Source File: BroadcastLeftSemiJoinHash.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
 * Physical operator for a hash-based semi join: a hash structure is built from the output
 * of `right`, broadcast with `sparkContext.broadcast`, and probed with `hashSemiJoin`
 * while iterating over the partitions of `left`.
 *
 * NOTE(review): several lines of this listing appear to be truncated (the closure opened
 * on `right.execute().map` is never closed, both `case` arms are empty and the closing
 * braces are missing). The comments below only describe what is visible; confirm the
 * details against the upstream source.
 */
case class BroadcastLeftSemiJoinHash(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    left: SparkPlan,
    right: SparkPlan,
    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

  // Metrics for this operator: rows read from each side and rows emitted.
  override private[sql] lazy val metrics = Map(
    "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
    "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numLeftRows = longMetric("numLeftRows")
    val numRightRows = longMetric("numRightRows")
    val numOutputRows = longMetric("numOutputRows")

    // Every row coming from the right side is counted in numRightRows.
    val input = right.execute().map { row =>
      numRightRows += 1

    if (condition.isEmpty) {
      // No extra join condition: a hash set built by buildKeyHashSet is broadcast.
      val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric)
      val broadcastedRelation = sparkContext.broadcast(hashSet)

      left.execute().mapPartitionsInternal { streamIter =>
        hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows)
    } else {
      // With a join condition a HashedRelation is built and broadcast instead.
      val hashRelation =
        HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size)
      val broadcastedRelation = sparkContext.broadcast(hashRelation)

      left.execute().mapPartitionsInternal { streamIter =>
        val hashedRelation = broadcastedRelation.value
        // NOTE(review): the bodies of both arms are missing from this listing.
        hashedRelation match {
          case unsafe: UnsafeHashedRelation =>
          case _ =>
        hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows)
Source File: Sort.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
 * Physical operator that sorts the rows of `child` per partition with an
 * `UnsafeExternalRowSorter`.
 *
 * @param sortOrder ordering to sort by; also reported as this operator's output ordering
 * @param global when true the child is required to provide an `OrderedDistribution` on
 *               `sortOrder`, otherwise an `UnspecifiedDistribution`
 * @param child operator whose output is sorted; its output attributes are passed through
 * @param testSpillFrequency defaults to 0 and is only consulted when greater than 0
 *
 * NOTE(review): several statements are missing from this listing (the body of
 * `computePrefix`, the body of the `testSpillFrequency` branch, the value returned from
 * the partition closure and all closing braces). Confirm against the upstream source.
 */
case class Sort(
    sortOrder: Seq[SortOrder],
    global: Boolean,
    child: SparkPlan,
    testSpillFrequency: Int = 0)
  extends UnaryNode {

  override def outputsUnsafeRows: Boolean = true
  override def canProcessUnsafeRows: Boolean = true
  override def canProcessSafeRows: Boolean = false

  override def output: Seq[Attribute] = child.output

  override def outputOrdering: Seq[SortOrder] = sortOrder

  override def requiredChildDistribution: Seq[Distribution] =
    if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

  // Metrics for this operator: peak memory used by the sorter and bytes spilled.
  override private[sql] lazy val metrics = Map(
    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
    "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))

  protected override def doExecute(): RDD[InternalRow] = {
    // Read once here so the per-partition closure below uses these local values.
    val schema = child.schema
    val childOutput = child.output

    val dataSize = longMetric("dataSize")
    val spillSize = longMetric("spillSize")

    child.execute().mapPartitionsInternal { iter =>
      val ordering = newOrdering(sortOrder, childOutput)

      // The comparator for comparing prefix
      // Only the first sort expression (sortOrder.head) is used for the prefix.
      val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
      val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)

      // The generator for prefix
      val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
      val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
        override def computePrefix(row: InternalRow): Long = {

      val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
      val sorter = new UnsafeExternalRowSorter(
        schema, ordering, prefixComparator, prefixComputer, pageSize)
      if (testSpillFrequency > 0) {

      // Remember spill data size of this task before execute this operator so that we can
      // figure out how many bytes we spilled for this operator.
      val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled

      // The incoming iterator is cast to UnsafeRow without a check.
      val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])

      dataSize += sorter.getPeakMemoryUsage
      spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore

Source File: ReferenceSort.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

/**
 * Physical operator that sorts the rows of `child` per partition with an `ExternalSorter`
 * and reports `sortOrder` as its output ordering.
 *
 * @param sortOrder ordering to sort by
 * @param global when true the child is required to provide an `OrderedDistribution` on
 *               `sortOrder`, otherwise an `UnspecifiedDistribution`
 * @param child operator whose output is sorted; its output attributes are passed through
 *
 * NOTE(review): this listing is truncated. The argument of `insertAll` has lost the part
 * before `=>`, `baseIterator` has no right-hand side, and closing braces are missing.
 * Confirm against the upstream source.
 */
case class ReferenceSort(
    sortOrder: Seq[SortOrder],
    global: Boolean,
    child: SparkPlan)
  extends UnaryNode {

  override def requiredChildDistribution: Seq[Distribution] =
    if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
    child.execute().mapPartitions( { iterator =>
      val ordering = newOrdering(sortOrder, child.output)
      val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
        TaskContext.get(), ordering = Some(ordering))
      // Rows are copied and inserted as keys with a null value.
      sorter.insertAll( => (r.copy(), null)))
      val baseIterator =
      val context = TaskContext.get()
      // sorter.stop() is handed to CompletionIterator together with the base iterator.
      CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
    }, preservesPartitioning = true)

  override def output: Seq[Attribute] = child.output

  override def outputOrdering: Seq[SortOrder] = sortOrder
Source File: FixedPointJobDefinition.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.fixedpoint

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD

import scala.collection.mutable.{HashSet, HashMap, Set}

/**
 * Mutable holder for the pieces of a fixed-point job: the iteration callbacks passed to
 * the constructor, the evaluator function, the final RDD and a set of RDD ids.
 *
 * @param setupIteration   callback that receives this definition and an RDD and returns an RDD
 * @param cleanupIteration callback that receives an `Int`
 */
class FixedPointJobDefinition(val setupIteration: (FixedPointJobDefinition, RDD[_]) => RDD[_],
                              val cleanupIteration: (Int) => Unit) {
  // Both references stay null until a caller assigns them.
  var _fixedPointEvaluator: (TaskContext, Iterator[_]) => Boolean = null
  var finalRDD: RDD[_] = null
  var rddIds: Array[Int] = Array.empty[Int] // for all and delta rdd id for FixedPointResultTask execution on worker

  /** Stores the evaluator function. */
  def fixedPointEvaluator(fixedPointEvaluator: (TaskContext, Iterator[_]) => Boolean): Unit = {
    _fixedPointEvaluator = fixedPointEvaluator
  }

  /** Returns the stored evaluator with its result type widened to a wildcard. */
  def getfixedPointEvaluator: (TaskContext, Iterator[_]) => _ =
    _fixedPointEvaluator.asInstanceOf[(TaskContext, Iterator[_]) => _]

  /** Returns the current value of `finalRDD` (null when it has not been assigned). */
  def getFinalRDD: RDD[_] = finalRDD

  /**
   * Replaces `rddIds` with a four-element array holding the arguments in declaration order:
   * new all, old all, new delta-prime, old delta-prime.
   */
  def setRDDIds(newAllRDDId: Int,
                oldAllRDDId: Int,
                newDeltaPrimeRDDId: Int,
                oldDeltaPrimeRDDId: Int): Unit = {

    rddIds = Array(newAllRDDId, oldAllRDDId, newDeltaPrimeRDDId, oldDeltaPrimeRDDId)
  }
}
Source File: FixedPointResultStage.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.fixedpoint

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{Stage, ResultStage}
import org.apache.spark.util.CallSite

/**
 * A `ResultStage` subclass for fixed-point jobs. All constructor arguments are forwarded
 * unchanged to `ResultStage`; `func` and `partitions` are additionally exposed as
 * overriding `val`s.
 */
class FixedPointResultStage(id: Int,
                            rdd: RDD[_],
                            override val func: (TaskContext, Iterator[_]) => _,
                            override val partitions: Array[Int],
                            parents: List[Stage],
                            jobId: Int,
                            callSite: CallSite)
  extends ResultStage(id, rdd, func, partitions, parents, jobId, callSite) {

  /** True when this stage was created with at least one parent stage. */
  def hasParent: Boolean = parents.nonEmpty

  override def toString: String = s"FixedPointResultStage $id"
}
Source File: SampledRDD.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.commons.math3.distribution.PoissonDistribution

import org.apache.spark.{Partition, TaskContext}

/**
 * Partition that pairs a parent partition with the seed to use when sampling it.
 * Its index is the index of the wrapped parent partition.
 *
 * @param prev the wrapped parent partition
 * @param seed seed associated with this partition
 */
@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0")
class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
  override val index: Int = prev.index
}

/**
 * RDD that samples the elements of `prev`, either with replacement (a Poisson-distributed
 * count per element) or without (each element is kept when a random draw is `<= frac`).
 *
 * @param prev parent RDD to sample from
 * @param withReplacement selects the Poisson branch when true, the filter branch when false
 * @param frac rate passed to `PoissonDistribution`, or threshold for the random draw
 * @param seed seed for the generator that produces one seed per partition
 *
 * NOTE(review): this listing is truncated. The expression before `=>` in `getPartitions`,
 * the body of `getPreferredLocations`, the `else` branch that follows `Iterator.empty` and
 * the closing braces are missing. Confirm against the upstream source.
 */
@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0")
private[spark] class SampledRDD[T: ClassTag](
    prev: RDD[T],
    withReplacement: Boolean,
    frac: Double,
    seed: Int)
  extends RDD[T](prev) {

  override def getPartitions: Array[Partition] = {
    // One generator, seeded with `seed`, hands a fresh Int seed to every partition wrapper.
    val rg = new Random(seed)
    firstParent[T] => new SampledRDDPartition(x, rg.nextInt))

  override def getPreferredLocations(split: Partition): Seq[String] =

  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
    // The partition is cast to SampledRDDPartition without a check.
    val split = splitIn.asInstanceOf[SampledRDDPartition]
    if (withReplacement) {
      // For large datasets, the expected number of occurrences of each element in a sample with
      // replacement is Poisson(frac). We use that to get a count for each element.
      val poisson = new PoissonDistribution(frac)

      firstParent[T].iterator(split.prev, context).flatMap { element =>
        val count = poisson.sample()
        if (count == 0) {
          Iterator.empty  // Avoid object allocation when we return 0 items, which is quite often
        } else {
    } else { // Sampling without replacement
      // The generator is seeded with the seed carried by the partition.
      val rand = new Random(split.seed)
      firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
Source File: ZippedWithIndexRDD.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.Utils

/**
 * Partition that pairs a parent partition with the index assigned to its first element.
 * Its index is the index of the wrapped parent partition.
 *
 * @param prev the wrapped parent partition
 * @param startIndex offset added to element positions within this partition
 */
class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
  extends Partition with Serializable {
  override val index: Int = prev.index
}

  // Start offsets per partition, built as a running sum (`scanLeft(0L)(_ + _)`) over values
  // obtained with `Utils.getIteratorSize` for partitions `0 until n - 1`.
  // NOTE(review): the bodies of the `n == 0` and `n == 1` branches, the call that the
  // arguments below belong to, and the closing brace are missing from this listing.
  @transient private val startIndices: Array[Long] = {
    val n = prev.partitions.length
    if (n == 0) {
    } else if (n == 1) {
    } else {
        Utils.getIteratorSize _,
        0 until n - 1 // do not need to count the last partition
      ).scanLeft(0L)(_ + _)

  // Wraps parent partitions in ZippedWithIndexRDDPartition, giving each one the start index
  // looked up in `startIndices` by the parent partition's own index.
  // NOTE(review): the expression before `=>` and the closing brace are missing from this
  // listing; confirm against the upstream source.
  override def getPartitions: Array[Partition] = {
    firstParent[T] => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))

  // NOTE(review): the right-hand side of this definition is missing from this listing;
  // confirm the body against the upstream source.
  override def getPreferredLocations(split: Partition): Seq[String] =

  // Pairs each element with `split.startIndex` plus the second component of the incoming
  // tuple. The partition is cast to ZippedWithIndexRDDPartition without a check.
  // NOTE(review): the call between the parent iterator and the `{ x =>` closure, and the
  // closing braces, are missing from this listing; confirm against the upstream source.
  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
    firstParent[T].iterator(split.prev, context) { x =>
      (x._1, split.startIndex + x._2)
Source File: MemoryCheckpointRDD.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext}

import scala.reflect.ClassTag

/**
 * Checkpoint RDD that behaves like `LocalCheckpointRDD`. A separate class is used so that
 * it can easily be identified (e.g., by pattern matching) in DAGScheduler.
 *
 * @param sc            context this RDD belongs to
 * @param rddId         id used to build the block id reported by `compute`
 * @param numPartitions number of partitions, forwarded to `LocalCheckpointRDD`
 */
class MemoryCheckpointRDD[T: ClassTag](sc: SparkContext, rddId: Int, numPartitions: Int)
  extends LocalCheckpointRDD[T](sc, rddId, numPartitions) {

  /** Builds the checkpoint RDD from an existing RDD's context, id and partition count. */
  def this(rdd: RDD[T]) = this(rdd.context, rdd.id, rdd.partitions.length)

  /**
   * Always throws a `SparkException` naming the checkpoint block of `partition` that
   * could not be found.
   */
  override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
    throw new SparkException(
      s"Checkpoint block ${RDDBlockId(rddId, partition.index)} not found! Either the executor " +
        "that originally checkpointed this partition is no longer alive, or the original RDD is " +
        "unpersisted. If this problem persists, you may consider using `rdd.checkpoint()` " +
        "or `rdd.localCheckpoint()` instead, which are slower than memory checkpointing " +
        "but more fault-tolerant.")
  }
}
Source File: UnionRDD.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

/**
 * Partition of a union that points at one partition of one parent RDD.
 *
 * @param idx index of this partition
 * @param rdd parent RDD that owns the referenced partition (transient)
 * @param parentRddIndex index of the parent RDD
 * @param parentRddPartitionIndex index of the referenced partition inside `rdd` (transient)
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient private val rdd: RDD[T],
    val parentRddIndex: Int,
    @transient private val parentRddPartitionIndex: Int)
  extends Partition {

  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  /** Preferred locations of the referenced parent partition, as reported by `rdd`. */
  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)
    // Write the non-transient fields; without this call nothing reaches the stream.
    oos.defaultWriteObject()
  }
}

class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  // Creates one UnionPartition per partition of every RDD in `rdds`, numbering them
  // consecutively with `pos` across all parents.
  // NOTE(review): the size argument of `new Array[Partition](`, the returned value and the
  // closing braces are missing from this listing; confirm against the upstream source.
  override def getPartitions: Array[Partition] = {
    val array = new Array[Partition](
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1

  // Adds one RangeDependency per parent RDD; `pos` advances by the parent's partition count
  // so that each parent's range starts where the previous one ended.
  // NOTE(review): the returned value and the closing braces are missing from this listing;
  // confirm against the upstream source.
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length)
      pos += rdd.partitions.length

  // Delegates to the parent RDD selected by `parentRddIndex`, iterating over the parent
  // partition stored in the UnionPartition. The partition is cast without a check.
  // NOTE(review): the closing brace is missing from this listing.
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)

  // NOTE(review): the right-hand side of this definition is missing from this listing;
  // confirm the body against the upstream source.
  override def getPreferredLocations(s: Partition): Seq[String] =

  override def clearDependencies() {
