org.apache.spark.util.ThreadUtils Scala Examples

The following examples show how to use org.apache.spark.util.ThreadUtils. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: RateController.scala    From drizzle-spark   with Apache License 2.0 6 votes vote down vote up
package org.apache.spark.streaming.scheduler

import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkConf
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.{ThreadUtils, Utils}


  /**
   * Asks `rateEstimator` for a new rate inside a `Future`. When the estimator yields a value,
   * it is stored in `rateLimit` (converted with `toLong`) and the latest rate is published.
   *
   * The `Future` itself is discarded; this method returns without waiting for it.
   */
  private def computeAndPublish(
      time: Long,
      elems: Long,
      workDelay: Long,
      waitDelay: Long): Unit = {
    Future[Unit] {
      for (estimate <- rateEstimator.compute(time, elems, workDelay, waitDelay)) {
        rateLimit.set(estimate.toLong)
        publish(getLatestRate())
      }
    }
  }

  /** Returns the value currently held by `rateLimit`. */
  def getLatestRate(): Long = {
    rateLimit.get()
  }

  /**
   * Triggers a rate computation for a completed batch.
   *
   * `computeAndPublish` is only invoked when the batch info carries a processing end time,
   * a processing delay, a scheduling delay, and an input-info entry for `streamUID`; if any of
   * them is absent nothing happens.
   *
   * @param batchCompleted the listener event describing the finished batch
   */
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    val elements = batchCompleted.batchInfo.streamIdToInputInfo

    for {
      processingEnd <- batchCompleted.batchInfo.processingEndTime
      workDelay <- batchCompleted.batchInfo.processingDelay
      waitDelay <- batchCompleted.batchInfo.schedulingDelay
      elems <- elements.get(streamUID).map(_.numRecords)
    } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
  }
}

/** Configuration helpers related to streaming back pressure. */
object RateController {

  /**
   * Reads the `spark.streaming.backpressure.enabled` flag from `conf`, passing `false` as the
   * default argument.
   */
  def isBackPressureEnabled(conf: SparkConf): Boolean = {
    val backPressureKey = "spark.streaming.backpressure.enabled"
    conf.getBoolean(backPressureKey, false)
  }
}
Example 2
Source File: LauncherBackend.scala    From drizzle-spark   with Apache License 2.0 6 votes vote down vote up
package org.apache.spark.launcher

import java.net.{InetAddress, Socket}

import org.apache.spark.SPARK_VERSION
import org.apache.spark.launcher.LauncherProtocol._
import org.apache.spark.util.{ThreadUtils, Utils}


  /** Disconnect hook; the body provided here is empty. */
  protected def onDisconnected(): Unit = ()

  /**
   * Runs `onStopRequest()` on a fresh thread obtained from `LauncherBackend.threadFactory`.
   * The call is wrapped in `Utils.tryLogNonFatalError`.
   */
  private def fireStopRequest(): Unit = {
    val stopTask = new Runnable {
      override def run(): Unit = Utils.tryLogNonFatalError(onStopRequest())
    }
    LauncherBackend.threadFactory.newThread(stopTask).start()
  }

  /**
   * Launcher connection bound to socket `s`.
   *
   * A `Stop` message triggers `fireStopRequest()`; any other message type is rejected with an
   * `IllegalArgumentException`.
   */
  private class BackendConnection(s: Socket) extends LauncherConnection(s) {

    override protected def handle(m: Message): Unit = m match {
      case _: Stop =>
        fireStopRequest()

      case _ =>
        throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}")
    }

    /**
     * Closes the underlying connection, then runs the `onDisconnected()` hook and clears the
     * connected flag.
     */
    override def close(): Unit = {
      try {
        super.close()
      } finally {
        // The flag is reset in its own `finally` so that an exception thrown by the
        // `onDisconnected()` hook cannot leave `_isConnected` stuck at `true`.
        try {
          onDisconnected()
        } finally {
          _isConnected = false
        }
      }
    }

  }

}

/** Holds the thread factory shared by launcher backend instances. */
private object LauncherBackend {

  /** Factory created via `ThreadUtils.namedThreadFactory` with the "LauncherBackend" prefix. */
  val threadFactory: java.util.concurrent.ThreadFactory =
    ThreadUtils.namedThreadFactory("LauncherBackend")

}
Example 3
Source File: ExecutorDelegationTokenUpdater.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.yarn

import java.util.concurrent.{Executors, TimeUnit}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.security.{Credentials, UserGroupInformation}

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.util.{ThreadUtils, Utils}

import scala.util.control.NonFatal

/**
 * Picks up refreshed delegation tokens from the credentials file location configured under
 * `spark.yarn.credentials.file` and adds them to the current user's credentials.
 *
 * Follow-up checks are scheduled on a single-threaded scheduled executor owned by this class;
 * call `stop()` to shut it down.
 */
private[spark] class ExecutorDelegationTokenUpdater(
    sparkConf: SparkConf,
    hadoopConf: Configuration) extends Logging {

  // Suffix of the newest credentials file whose tokens have already been loaded.
  @volatile private var lastCredentialsFileSuffix = 0

  private val credentialsFile = sparkConf.get("spark.yarn.credentials.file")

  private val delegationTokenRenewer =
    Executors.newSingleThreadScheduledExecutor(
      ThreadUtils.namedThreadFactory("Delegation Token Refresh Thread"))

  // On the executor, this thread wakes up and picks up new tokens from HDFS, if any.
  private val executorUpdaterRunnable =
    new Runnable {
      override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired())
    }

  /**
   * Looks at the last entry of the sorted credentials-file listing and acts on it:
   *
   *  - if its suffix is not greater than the last one loaded, a re-check is scheduled in one
   *    hour and nothing else happens;
   *  - if its suffix is greater, the tokens are read and added to the current user, then the
   *    next refresh is scheduled;
   *  - if the listing is empty, only the next refresh is scheduled.
   *
   * Any non-fatal error is logged and a retry is scheduled in one hour.
   */
  def updateCredentialsIfRequired(): Unit = {
    try {
      val credentialsFilePath = new Path(credentialsFile)
      val remoteFs = FileSystem.get(hadoopConf)
      val latestFile = SparkHadoopUtil.get.listFilesSorted(
        remoteFs, credentialsFilePath.getParent,
        credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
        .lastOption
        .map { status =>
          (status.getPath, SparkHadoopUtil.get.getSuffixForCredentialsPath(status.getPath))
        }

      // A pattern match replaces the former non-local `return` from inside a `foreach` lambda.
      latestFile match {
        case Some((_, suffix)) if suffix <= lastCredentialsFileSuffix =>
          // Check every hour to see if new credentials arrived.
          logInfo("Updated delegation tokens were expected, but the driver has not updated the " +
            "tokens yet, will check again in an hour.")
          delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS)

        case Some((tokenPath, suffix)) =>
          logInfo("Reading new delegation tokens from " + tokenPath)
          val newCredentials = getCredentialsFromHDFSFile(remoteFs, tokenPath)
          lastCredentialsFileSuffix = suffix
          UserGroupInformation.getCurrentUser.addCredentials(newCredentials)
          logInfo("Tokens updated from credentials file.")
          scheduleNextRefresh()

        case None =>
          scheduleNextRefresh()
      }
    } catch {
      // Since the file may get deleted while we are reading it, catch the Exception and come
      // back in an hour to try again
      case NonFatal(e) =>
        logWarning("Error while trying to update credentials, will try again in 1 hour", e)
        delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS)
    }
  }

  /**
   * Schedules the next token refresh based on the time until renewal computed from the current
   * user's credentials; when that time is not positive the refresh runs immediately on the
   * calling thread.
   */
  private def scheduleNextRefresh(): Unit = {
    val timeFromNowToRenewal =
      SparkHadoopUtil.get.getTimeFromNowToRenewal(
        sparkConf, 0.8, UserGroupInformation.getCurrentUser.getCredentials)
    if (timeFromNowToRenewal <= 0) {
      executorUpdaterRunnable.run()
    } else {
      logInfo(s"Scheduling token refresh from HDFS in $timeFromNowToRenewal millis.")
      delegationTokenRenewer.schedule(
        executorUpdaterRunnable, timeFromNowToRenewal, TimeUnit.MILLISECONDS)
    }
  }

  /** Reads a token storage stream from `tokenPath` on `remoteFs`; the stream is always closed. */
  private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = {
    val stream = remoteFs.open(tokenPath)
    try {
      val newCredentials = new Credentials()
      newCredentials.readTokenStorageStream(stream)
      newCredentials
    } finally {
      stream.close()
    }
  }

  /** Shuts down the scheduling executor. */
  def stop(): Unit = {
    delegationTokenRenewer.shutdown()
  }

}
Example 4
Source File: BlockManagerSlaveEndpoint.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}


/**
 * RPC endpoint that executes block-manager commands received through `receiveAndReply`.
 *
 * Removal commands are run asynchronously on a daemon cached thread pool; status queries and
 * thread dumps are answered directly on the calling thread.
 */
private[storage]
class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  // Explicitly typed: implicit definitions should not rely on an inferred type.
  private implicit val asyncExecutionContext: ExecutionContext =
    ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())
  }

  /**
   * Evaluates `body` in a `Future` on the implicit execution context. On success the result is
   * sent back through `context.reply`; on failure the error is logged and forwarded with
   * `context.sendFailure`.
   *
   * @param actionMessage description of the action, used in log messages
   * @param context       RPC call context used to answer the sender
   * @param body          the action to run asynchronously
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    // `foreach` / `failed.foreach` replace the `onSuccess` / `onFailure` callbacks, which are
    // deprecated since Scala 2.12; both forms only fire for the matching outcome.
    future.foreach { response =>
      logDebug(s"Done $actionMessage, response is $response")
      context.reply(response)
      logDebug(s"Sent response: $response to ${context.senderAddress}")
    }
    future.failed.foreach { t =>
      logError(s"Error in $actionMessage", t)
      context.sendFailure(t)
    }
  }

  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 5
Source File: BlockManagerSlaveEndpoint.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
import org.apache.spark.storage.BlockManagerMessages._


/**
 * RPC endpoint that executes block-manager commands received through `receiveAndReply`.
 *
 * Removal commands are run asynchronously on a daemon cached thread pool; status queries are
 * answered directly on the calling thread.
 */
private[storage]
class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends RpcEndpoint with Logging {

  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  // Explicitly typed: implicit definitions should not rely on an inferred type.
  private implicit val asyncExecutionContext: ExecutionContext =
    ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))
  }

  /**
   * Evaluates `body` in a `Future` on the implicit execution context. On success the result is
   * sent back through `context.reply`; on failure the error is logged and forwarded with
   * `context.sendFailure`.
   *
   * @param actionMessage description of the action, used in log messages
   * @param context       RPC call context used to answer the sender
   * @param body          the action to run asynchronously
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    // `foreach` / `failed.foreach` replace the `onSuccess` / `onFailure` callbacks, which are
    // deprecated since Scala 2.12; both forms only fire for the matching outcome.
    future.foreach { response =>
      logDebug(s"Done $actionMessage, response is $response")
      context.reply(response)
      logDebug(s"Sent response: $response to ${context.sender}")
    }
    future.failed.foreach { t =>
      logError(s"Error in $actionMessage", t)
      context.sendFailure(t)
    }
  }

  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 6
Source File: ExecutorDelegationTokenUpdater.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.yarn

import java.util.concurrent.{Executors, TimeUnit}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.security.{Credentials, UserGroupInformation}

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.util.{ThreadUtils, Utils}

import scala.util.control.NonFatal

/**
 * Picks up refreshed delegation tokens from the credentials file location configured under
 * `spark.yarn.credentials.file` and adds them to the current user's credentials.
 *
 * Follow-up checks are scheduled on a single-threaded scheduled executor owned by this class;
 * call `stop()` to shut it down.
 */
private[spark] class ExecutorDelegationTokenUpdater(
    sparkConf: SparkConf,
    hadoopConf: Configuration) extends Logging {

  // Suffix of the newest credentials file whose tokens have already been loaded.
  @volatile private var lastCredentialsFileSuffix = 0

  private val credentialsFile = sparkConf.get("spark.yarn.credentials.file")
  // Configuration obtained from `getConfBypassingFSCache` for the credentials file's scheme.
  private val freshHadoopConf =
    SparkHadoopUtil.get.getConfBypassingFSCache(
      hadoopConf, new Path(credentialsFile).toUri.getScheme)

  private val delegationTokenRenewer =
    Executors.newSingleThreadScheduledExecutor(
      ThreadUtils.namedThreadFactory("Delegation Token Refresh Thread"))

  // On the executor, this thread wakes up and picks up new tokens from HDFS, if any.
  private val executorUpdaterRunnable =
    new Runnable {
      override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired())
    }

  /**
   * Looks at the last entry of the sorted credentials-file listing and acts on it:
   *
   *  - if its suffix is not greater than the last one loaded, a re-check is scheduled in one
   *    hour and nothing else happens;
   *  - if its suffix is greater, the tokens are read and added to the current user, then the
   *    next refresh is scheduled;
   *  - if the listing is empty, only the next refresh is scheduled.
   *
   * Any non-fatal error is logged and a retry is scheduled in one hour.
   */
  def updateCredentialsIfRequired(): Unit = {
    try {
      val credentialsFilePath = new Path(credentialsFile)
      val remoteFs = FileSystem.get(freshHadoopConf)
      val latestFile = SparkHadoopUtil.get.listFilesSorted(
        remoteFs, credentialsFilePath.getParent,
        credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
        .lastOption
        .map { status =>
          (status.getPath, SparkHadoopUtil.get.getSuffixForCredentialsPath(status.getPath))
        }

      // A pattern match replaces the former non-local `return` from inside a `foreach` lambda.
      latestFile match {
        case Some((_, suffix)) if suffix <= lastCredentialsFileSuffix =>
          // Check every hour to see if new credentials arrived.
          logInfo("Updated delegation tokens were expected, but the driver has not updated the " +
            "tokens yet, will check again in an hour.")
          delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS)

        case Some((tokenPath, suffix)) =>
          logInfo("Reading new delegation tokens from " + tokenPath)
          val newCredentials = getCredentialsFromHDFSFile(remoteFs, tokenPath)
          lastCredentialsFileSuffix = suffix
          UserGroupInformation.getCurrentUser.addCredentials(newCredentials)
          logInfo("Tokens updated from credentials file.")
          scheduleNextRefresh()

        case None =>
          scheduleNextRefresh()
      }
    } catch {
      // Since the file may get deleted while we are reading it, catch the Exception and come
      // back in an hour to try again
      case NonFatal(e) =>
        logWarning("Error while trying to update credentials, will try again in 1 hour", e)
        delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS)
    }
  }

  /**
   * Schedules the next token refresh based on the time until renewal computed from the current
   * user's credentials; when that time is not positive the refresh runs immediately on the
   * calling thread.
   */
  private def scheduleNextRefresh(): Unit = {
    val timeFromNowToRenewal =
      SparkHadoopUtil.get.getTimeFromNowToRenewal(
        sparkConf, 0.8, UserGroupInformation.getCurrentUser.getCredentials)
    if (timeFromNowToRenewal <= 0) {
      executorUpdaterRunnable.run()
    } else {
      logInfo(s"Scheduling token refresh from HDFS in $timeFromNowToRenewal millis.")
      delegationTokenRenewer.schedule(
        executorUpdaterRunnable, timeFromNowToRenewal, TimeUnit.MILLISECONDS)
    }
  }

  /** Reads a token storage stream from `tokenPath` on `remoteFs`; the stream is always closed. */
  private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = {
    val stream = remoteFs.open(tokenPath)
    try {
      val newCredentials = new Credentials()
      newCredentials.readTokenStorageStream(stream)
      newCredentials
    } finally {
      stream.close()
    }
  }

  /** Shuts down the scheduling executor. */
  def stop(): Unit = {
    delegationTokenRenewer.shutdown()
  }

}
Example 7
Source File: RateController.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkConf
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.{ThreadUtils, Utils}


  /**
   * Asks `rateEstimator` for a new rate inside a `Future`. When the estimator yields a value,
   * it is stored in `rateLimit` (converted with `toLong`) and the latest rate is published.
   *
   * The `Future` itself is discarded; this method returns without waiting for it.
   */
  private def computeAndPublish(
      time: Long,
      elems: Long,
      workDelay: Long,
      waitDelay: Long): Unit = {
    Future[Unit] {
      for (estimate <- rateEstimator.compute(time, elems, workDelay, waitDelay)) {
        rateLimit.set(estimate.toLong)
        publish(getLatestRate())
      }
    }
  }

  /** Returns the value currently held by `rateLimit`. */
  def getLatestRate(): Long = {
    rateLimit.get()
  }

  /**
   * Triggers a rate computation for a completed batch.
   *
   * `computeAndPublish` is only invoked when the batch info carries a processing end time,
   * a processing delay, a scheduling delay, and an input-info entry for `streamUID`; if any of
   * them is absent nothing happens.
   *
   * @param batchCompleted the listener event describing the finished batch
   */
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    val elements = batchCompleted.batchInfo.streamIdToInputInfo

    for {
      processingEnd <- batchCompleted.batchInfo.processingEndTime
      workDelay <- batchCompleted.batchInfo.processingDelay
      waitDelay <- batchCompleted.batchInfo.schedulingDelay
      elems <- elements.get(streamUID).map(_.numRecords)
    } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
  }
}

/** Configuration helpers related to streaming back pressure. */
object RateController {

  /**
   * Reads the `spark.streaming.backpressure.enabled` flag from `conf`, passing `false` as the
   * default argument.
   */
  def isBackPressureEnabled(conf: SparkConf): Boolean = {
    val backPressureKey = "spark.streaming.backpressure.enabled"
    conf.getBoolean(backPressureKey, false)
  }
}
Example 8
Source File: BlockManagerSlaveEndpoint.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import scala.concurrent.{ ExecutionContext, Future }

import org.apache.spark.rpc.{ RpcEnv, RpcCallContext, RpcEndpoint }
import org.apache.spark.util.ThreadUtils
import org.apache.spark.{ Logging, MapOutputTracker, SparkEnv }
import org.apache.spark.storage.BlockManagerMessages._


/**
 * RPC endpoint that executes block-manager commands received through `receiveAndReply`.
 *
 * Removal commands are run asynchronously on a daemon cached thread pool; status queries are
 * answered directly on the calling thread.
 */
private[storage] class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager, // block manager that carries out the requested operations
    mapOutputTracker: MapOutputTracker)
  extends RpcEndpoint with Logging {

  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  // Explicitly typed: implicit definitions should not rely on an inferred type.
  private implicit val asyncExecutionContext: ExecutionContext =
    ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    // Remove the block with the given id.
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    // Remove the blocks of the given RDD.
    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    // Unregister the given shuffle from the map output tracker (if any) and the shuffle manager.
    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    // Remove the blocks of the given broadcast variable.
    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        // tellMaster: presumably whether the status change is reported to the master
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    // Reply with the status of the given block.
    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    // Reply with the ids of the blocks matching the given filter.
    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))
  }

  /**
   * Evaluates `body` in a `Future` on the implicit execution context. On success the result is
   * sent back through `context.reply`; on failure the error is logged and forwarded with
   * `context.sendFailure`.
   *
   * @param actionMessage description of the action, used in log messages
   * @param context       RPC call context used to answer the sender
   * @param body          the action to run asynchronously
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    // `foreach` / `failed.foreach` replace the `onSuccess` / `onFailure` callbacks, which are
    // deprecated since Scala 2.12; both forms only fire for the matching outcome.
    future.foreach { response =>
      logDebug(s"Done $actionMessage, response is $response")
      context.reply(response)
      logDebug(s"Sent response: $response to ${context.sender}")
    }
    future.failed.foreach { t =>
      logError(s"Error in $actionMessage", t)
      context.sendFailure(t)
    }
  }

  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 9
Source File: HasParallelism.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.param.shared

import scala.concurrent.ExecutionContext

import org.apache.spark.ml.param.{IntParam, Params, ParamValidators}
import org.apache.spark.util.ThreadUtils


  /**
   * Builds the execution context matching the current parallelism setting: `ThreadUtils.sameThread`
   * when `getParallelism` is 1, otherwise a context backed by a newly created daemon cached thread
   * pool named after this class and sized by the parallelism value.
   */
  private[ml] def getExecutionContext: ExecutionContext = {
    val parallelism = getParallelism
    if (parallelism == 1) {
      ThreadUtils.sameThread
    } else {
      val poolName = s"${this.getClass.getSimpleName}-thread-pool"
      ExecutionContext.fromExecutorService(
        ThreadUtils.newDaemonCachedThreadPool(poolName, parallelism))
    }
  }
} 
Example 10
Source File: SparkPodInitContainer.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.k8s

import java.io.File
import java.util.concurrent.TimeUnit

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.internal.Logging
import org.apache.spark.util.{ThreadUtils, Utils}


/**
 * Downloads the remote jars and files listed in the Spark configuration into their configured
 * download directories, using a daemon cached thread pool for the individual fetches.
 *
 * @param sparkConf   configuration supplying the remote URIs, target directories, pool size
 *                    and timeout
 * @param fileFetcher performs the actual fetch of a single URI
 */
private[spark] class SparkPodInitContainer(
    sparkConf: SparkConf,
    fileFetcher: FileFetcher) extends Logging {

  private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE)
  private implicit val downloadExecutor = ExecutionContext.fromExecutorService(
    ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize))

  private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION))
  private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION))

  private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS)
  private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES)

  private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT)

  /**
   * Submits all jar and file downloads, then shuts the executor down and waits up to
   * `downloadTimeoutMinutes` minutes for the submitted downloads to finish. A warning is logged
   * if the wait times out.
   */
  def run(): Unit = {
    logInfo(s"Downloading remote jars: $remoteJars")
    downloadFiles(
      remoteJars,
      jarsDownloadDir,
      s"Remote jars download directory specified at $jarsDownloadDir does not exist " +
        "or is not a directory.")

    logInfo(s"Downloading remote files: $remoteFiles")
    downloadFiles(
      remoteFiles,
      filesDownloadDir,
      s"Remote files download directory specified at $filesDownloadDir does not exist " +
        "or is not a directory.")

    downloadExecutor.shutdown()
    // `awaitTermination` returns false on timeout; surface that instead of ignoring it.
    val finishedInTime =
      downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES)
    if (!finishedInTime) {
      logWarning(s"Timed out after $downloadTimeoutMinutes minutes waiting for downloads " +
        "to complete; some dependencies may be missing.")
    }
  }

  /**
   * Starts one asynchronous fetch per entry of the comma-separated list, if a list is given.
   *
   * @param filesCommaSeparated optional comma-separated list of URIs to fetch
   * @param downloadDir         target directory; must already be a directory
   * @param errMessage          message of the `IllegalArgumentException` raised by `require`
   *                            when `downloadDir` is not a directory
   */
  private def downloadFiles(
      filesCommaSeparated: Option[String],
      downloadDir: File,
      errMessage: String): Unit = {
    filesCommaSeparated.foreach { files =>
      require(downloadDir.isDirectory, errMessage)
      Utils.stringToSeq(files).foreach { file =>
        Future[Unit] {
          // The Future is never inspected, so log failures here; otherwise a failed download
          // would disappear silently. The exception is rethrown to keep the Future failed.
          try {
            fileFetcher.fetchFile(file, downloadDir)
          } catch {
            case scala.util.control.NonFatal(e) =>
              logError(s"Failed to download $file to $downloadDir", e)
              throw e
          }
        }
      }
    }
  }
}

/** Thin wrapper around `Utils.fetchFile` bound to a Spark configuration and security manager. */
private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) {

  /**
   * Fetches `uri` into `targetDir` through `Utils.fetchFile`, using a Hadoop configuration
   * created from `sparkConf`, the current time as the timestamp, and the cache disabled.
   */
  def fetchFile(uri: String, targetDir: File): Unit = {
    val freshHadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf)
    val fetchTimestamp = System.currentTimeMillis()
    Utils.fetchFile(
      url = uri,
      targetDir = targetDir,
      conf = sparkConf,
      securityMgr = securityManager,
      hadoopConf = freshHadoopConf,
      timestamp = fetchTimestamp,
      useCache = false)
  }
}

/** Entry point that builds a `SparkPodInitContainer` and runs it. */
object SparkPodInitContainer extends Logging {

  /**
   * Creates the Spark configuration (loading default properties from the first argument when
   * one is supplied), wires up the security manager and file fetcher, and runs the downloads.
   */
  def main(args: Array[String]): Unit = {
    logInfo("Starting init-container to download Spark application dependencies.")
    val sparkConf = new SparkConf(true)
    args.headOption.foreach { propertiesFile =>
      Utils.loadDefaultSparkProperties(sparkConf, propertiesFile)
    }

    val fileFetcher = new FileFetcher(sparkConf, new SparkSecurityManager(sparkConf))
    val initContainer = new SparkPodInitContainer(sparkConf, fileFetcher)
    initContainer.run()
    logInfo("Finished downloading application dependencies.")
  }
}
Example 11
Source File: SparkKubernetesClientFactory.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.k8s

import java.io.File

import com.google.common.base.Charsets
import com.google.common.io.Files
import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient}
import io.fabric8.kubernetes.client.utils.HttpClientUtils
import okhttp3.Dispatcher

import org.apache.spark.SparkConf
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.util.ThreadUtils


/** Factory for `KubernetesClient` instances configured from Spark settings. */
private[spark] object SparkKubernetesClientFactory {

  /**
   * Builds a Kubernetes client for `master`.
   *
   * Authentication settings are read from `sparkConf` under `kubernetesAuthConfPrefix`. The
   * OAuth token file and CA certificate fall back to the supplied service-account defaults
   * when they are not configured. Specifying the OAuth token both as a file and as a value is
   * rejected through `KubernetesUtils.requireNandDefined`. The HTTP client uses a dispatcher
   * backed by a daemon cached thread pool.
   */
  def createKubernetesClient(
      master: String,
      namespace: Option[String],
      kubernetesAuthConfPrefix: String,
      sparkConf: SparkConf,
      defaultServiceAccountToken: Option[File],
      defaultServiceAccountCaCert: Option[File]): KubernetesClient = {
    val tokenFileKey = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_FILE_CONF_SUFFIX"
    val tokenValueKey = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_CONF_SUFFIX"
    val configuredTokenFile = sparkConf.getOption(tokenFileKey).map(new File(_))
    val tokenFile = configuredTokenFile.orElse(defaultServiceAccountToken)
    val tokenValue = sparkConf.getOption(tokenValueKey)
    KubernetesUtils.requireNandDefined(
      tokenFile,
      tokenValue,
      s"Cannot specify OAuth token through both a file $tokenFileKey and a " +
        s"value $tokenValueKey.")

    val caCert = sparkConf
      .getOption(s"$kubernetesAuthConfPrefix.$CA_CERT_FILE_CONF_SUFFIX")
      .orElse(defaultServiceAccountCaCert.map(_.getAbsolutePath))
    val clientKey = sparkConf
      .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_KEY_FILE_CONF_SUFFIX")
    val clientCert = sparkConf
      .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_CERT_FILE_CONF_SUFFIX")
    val requestDispatcher = new Dispatcher(
      ThreadUtils.newDaemonCachedThreadPool("kubernetes-dispatcher"))

    // Each `withOption` step only touches the builder when the corresponding setting exists.
    val clientConfig = new ConfigBuilder()
      .withApiVersion("v1")
      .withMasterUrl(master)
      .withWebsocketPingInterval(0)
      .withOption(tokenValue) { (token, builder) =>
        builder.withOauthToken(token)
      }
      .withOption(tokenFile) { (file, builder) =>
        builder.withOauthToken(Files.toString(file, Charsets.UTF_8))
      }
      .withOption(caCert) { (file, builder) =>
        builder.withCaCertFile(file)
      }
      .withOption(clientKey) { (file, builder) =>
        builder.withClientKeyFile(file)
      }
      .withOption(clientCert) { (file, builder) =>
        builder.withClientCertFile(file)
      }
      .withOption(namespace) { (ns, builder) =>
        builder.withNamespace(ns)
      }
      .build()

    val httpClient = HttpClientUtils.createHttpClient(clientConfig)
      .newBuilder()
      .dispatcher(requestDispatcher)
      .build()
    new DefaultKubernetesClient(httpClient, clientConfig)
  }

  /** Adds `withOption` to `ConfigBuilder` for applying a setting only when it is defined. */
  private implicit class OptionConfigurableConfigBuilder(val configBuilder: ConfigBuilder)
    extends AnyVal {

    /**
     * Applies `configurator` to the option's value and this builder when `option` is defined;
     * otherwise returns the builder unchanged.
     */
    def withOption[T]
        (option: Option[T])
        (configurator: ((T, ConfigBuilder) => ConfigBuilder)): ConfigBuilder = {
      option.fold(configBuilder)(value => configurator(value, configBuilder))
    }
  }
}
Example 12
Source File: RateController.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkConf
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.{ThreadUtils, Utils}


  /**
   * Asks `rateEstimator` for a new rate inside a `Future`. When the estimator yields a value,
   * it is stored in `rateLimit` (converted with `toLong`) and the latest rate is published.
   *
   * The `Future` itself is discarded; this method returns without waiting for it.
   */
  private def computeAndPublish(
      time: Long,
      elems: Long,
      workDelay: Long,
      waitDelay: Long): Unit = {
    Future[Unit] {
      for (estimate <- rateEstimator.compute(time, elems, workDelay, waitDelay)) {
        rateLimit.set(estimate.toLong)
        publish(getLatestRate())
      }
    }
  }

  /** Returns the value currently held by `rateLimit`. */
  def getLatestRate(): Long = {
    rateLimit.get()
  }

  /**
   * Triggers a rate computation for a completed batch.
   *
   * `computeAndPublish` is only invoked when the batch info carries a processing end time,
   * a processing delay, a scheduling delay, and an input-info entry for `streamUID`; if any of
   * them is absent nothing happens.
   *
   * @param batchCompleted the listener event describing the finished batch
   */
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    val elements = batchCompleted.batchInfo.streamIdToInputInfo

    for {
      processingEnd <- batchCompleted.batchInfo.processingEndTime
      workDelay <- batchCompleted.batchInfo.processingDelay
      waitDelay <- batchCompleted.batchInfo.schedulingDelay
      elems <- elements.get(streamUID).map(_.numRecords)
    } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
  }
}

/** Configuration helpers related to streaming back pressure. */
object RateController {

  /**
   * Reads the `spark.streaming.backpressure.enabled` flag from `conf`, passing `false` as the
   * default argument.
   */
  def isBackPressureEnabled(conf: SparkConf): Boolean = {
    val backPressureKey = "spark.streaming.backpressure.enabled"
    conf.getBoolean(backPressureKey, false)
  }
}
Example 13
Source File: BlockTransferService.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network

import java.io.Closeable
import java.nio.ByteBuffer

import scala.concurrent.{Future, Promise}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager}
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.ThreadUtils

/** Base class for block transfer services; only the synchronous upload helper is shown here. */
private[spark]
abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {

  /**
   * Starts an upload through `uploadBlock` and blocks the calling thread until the resulting
   * future completes. The wait uses `Duration.Inf`, i.e. it has no timeout.
   */
  def uploadBlockSync(
      hostname: String,
      port: Int,
      execId: String,
      blockId: BlockId,
      blockData: ManagedBuffer,
      level: StorageLevel,
      classTag: ClassTag[_]): Unit = {
    ThreadUtils.awaitResult(
      uploadBlock(hostname, port, execId, blockId, blockData, level, classTag),
      Duration.Inf)
  }
}
Example 14
Source File: LauncherBackend.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.launcher

import java.net.{InetAddress, Socket}

import org.apache.spark.{SPARK_VERSION, SparkConf}
import org.apache.spark.launcher.LauncherProtocol._
import org.apache.spark.util.{ThreadUtils, Utils}


  /** Disconnect hook; the body provided here is empty. */
  protected def onDisconnected(): Unit = ()

  /**
   * Runs `onStopRequest()` on a fresh thread obtained from `LauncherBackend.threadFactory`.
   * The call is wrapped in `Utils.tryLogNonFatalError`.
   */
  private def fireStopRequest(): Unit = {
    val stopTask = new Runnable {
      override def run(): Unit = Utils.tryLogNonFatalError(onStopRequest())
    }
    LauncherBackend.threadFactory.newThread(stopTask).start()
  }

  /**
   * Launcher connection bound to socket `s`.
   *
   * A `Stop` message triggers `fireStopRequest()`; any other message type is rejected with an
   * `IllegalArgumentException`.
   */
  private class BackendConnection(s: Socket) extends LauncherConnection(s) {

    override protected def handle(m: Message): Unit = m match {
      case _: Stop =>
        fireStopRequest()

      case _ =>
        throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}")
    }

    /**
     * Closes the underlying connection, then runs the `onDisconnected()` hook and clears the
     * connected flag.
     */
    override def close(): Unit = {
      try {
        super.close()
      } finally {
        // The flag is reset in its own `finally` so that an exception thrown by the
        // `onDisconnected()` hook cannot leave `_isConnected` stuck at `true`.
        try {
          onDisconnected()
        } finally {
          _isConnected = false
        }
      }
    }

  }

}

/** Holds the thread factory shared by launcher backend instances. */
private object LauncherBackend {

  /** Factory created via `ThreadUtils.namedThreadFactory` with the "LauncherBackend" prefix. */
  val threadFactory: java.util.concurrent.ThreadFactory =
    ThreadUtils.namedThreadFactory("LauncherBackend")

}
Example 15
Source File: BlockManagerSlaveEndpoint.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{MapOutputTracker, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}


/**
 * RPC endpoint that answers block-management requests using the given `BlockManager`.
 * Removal requests are executed on a dedicated thread pool through `doAsync`; the other
 * requests reply directly from the handler.
 */
private[storage]
class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  // Pool that backs every doAsync call; it is shut down in onStop().
  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  private implicit val asyncExecutionContext: scala.concurrent.ExecutionContextExecutorService =
    ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context)(blockManager.removeRdd(rddId))

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        // The tracker may be absent, in which case only the shuffle manager is updated.
        Option(mapOutputTracker).foreach(_.unregisterShuffle(shuffleId))
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      val status = blockManager.getStatus(blockId)
      context.reply(status)

    case GetMatchingBlockIds(filter, _) =>
      val matching = blockManager.getMatchingBlockIds(filter)
      context.reply(matching)

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())

    case ReplicateBlock(blockId, replicas, maxReplicas) =>
      val replicated = blockManager.replicateBlock(blockId, replicas.toSet, maxReplicas)
      context.reply(replicated)

  }

  /**
   * Evaluates `body` on the async pool. On success the result is sent back through
   * `context.reply`; on failure the error is logged and reported with `context.sendFailure`.
   *
   * @param actionMessage description of the action, used in log messages
   * @param context       RPC call context that receives the reply or the failure
   * @param body          computation to run asynchronously
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val pending = Future {
      logDebug(actionMessage)
      body
    }
    pending.onComplete {
      case scala.util.Success(response) =>
        logDebug(s"Done $actionMessage, response is $response")
        context.reply(response)
        logDebug(s"Sent response: $response to ${context.senderAddress}")
      case scala.util.Failure(t) =>
        logError(s"Error in $actionMessage", t)
        context.sendFailure(t)
    }
  }

  /** Stops the async pool via `shutdownNow()`. */
  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 16
Source File: FutureActionSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import scala.concurrent.duration.Duration

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark.util.ThreadUtils


/** Checks the result and the number of job ids reported by asynchronous RDD actions. */
class FutureActionSuite
  extends SparkFunSuite
  with BeforeAndAfter
  with Matchers
  with LocalSparkContext {

  // Registers creation of a local SparkContext through the BeforeAndAfter hook.
  before {
    sc = new SparkContext("local", "FutureActionSuite")
  }

  test("simple async action") {
    val countJob = sc.parallelize(1 to 10, 2).countAsync()
    val count = ThreadUtils.awaitResult(countJob, Duration.Inf)
    count shouldBe 10
    countJob.jobIds.size shouldBe 1
  }

  test("complex async action") {
    val takeJob = sc.parallelize(1 to 15, 3).takeAsync(10)
    val taken = ThreadUtils.awaitResult(takeJob, Duration.Inf)
    taken shouldBe (1 to 10)
    takeJob.jobIds.size shouldBe 2
  }

}
Example 17
Source File: ExecutorDelegationTokenUpdater.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.yarn

import java.util.concurrent.{Executors, TimeUnit}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.security.{Credentials, UserGroupInformation}

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.util.{ThreadUtils, Utils}

import scala.util.control.NonFatal

/**
 * Loads delegation tokens from credentials files found next to `spark.yarn.credentials.file`
 * and adds them to the current user's credentials. Every call to
 * `updateCredentialsIfRequired` schedules the next attempt on a single-thread scheduler.
 */
private[spark] class ExecutorDelegationTokenUpdater(
    sparkConf: SparkConf,
    hadoopConf: Configuration) extends Logging {

  // Suffix of the newest credentials file loaded so far; 0 until the first load.
  @volatile private var lastCredentialsFileSuffix = 0

  private val credentialsFile = sparkConf.get("spark.yarn.credentials.file")
  private val freshHadoopConf =
    SparkHadoopUtil.get.getConfBypassingFSCache(
      hadoopConf, new Path(credentialsFile).toUri.getScheme)

  private val delegationTokenRenewer =
    Executors.newSingleThreadScheduledExecutor(
      ThreadUtils.namedThreadFactory("Delegation Token Refresh Thread"))

  // On the executor, this thread wakes up and picks up new tokens from HDFS, if any.
  private val executorUpdaterRunnable =
    new Runnable {
      override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired())
    }

  /**
   * Looks for the latest credentials file and, when its suffix is greater than the last one
   * loaded, reads it and adds the tokens to the current user. Afterwards the next run is
   * scheduled:
   *  - in one hour when a file exists but is not newer than the one already loaded,
   *  - in one minute when the computed time to renewal is not positive,
   *  - after the computed time to renewal otherwise,
   *  - in one hour when a non-fatal error occurs.
   */
  def updateCredentialsIfRequired(): Unit = {
    try {
      val credentialsFilePath = new Path(credentialsFile)
      val remoteFs = FileSystem.get(freshHadoopConf)
      val latestCredentialsFile = SparkHadoopUtil.get.listFilesSorted(
        remoteFs, credentialsFilePath.getParent,
        credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
        .lastOption

      // True only when a file was found but it is not newer than what is already loaded; in
      // that case the hourly re-check has been scheduled and nothing else must be scheduled.
      // This replaces a non-local `return` from inside the closure.
      val waitingForNewerFile = latestCredentialsFile.exists { credentialsStatus =>
        val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath)
        if (suffix > lastCredentialsFileSuffix) {
          logInfo("Reading new delegation tokens from " + credentialsStatus.getPath)
          val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath)
          lastCredentialsFileSuffix = suffix
          UserGroupInformation.getCurrentUser.addCredentials(newCredentials)
          logInfo("Tokens updated from credentials file.")
          false
        } else {
          // Check every hour to see if new credentials arrived.
          logInfo("Updated delegation tokens were expected, but the driver has not updated the " +
            "tokens yet, will check again in an hour.")
          delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS)
          true
        }
      }

      if (!waitingForNewerFile) {
        val timeFromNowToRenewal =
          SparkHadoopUtil.get.getTimeFromNowToRenewal(
            sparkConf, 0.8, UserGroupInformation.getCurrentUser.getCredentials)
        if (timeFromNowToRenewal <= 0) {
          // We just checked for new credentials but none were there, wait a minute and retry.
          // This handles the shutdown case where the staging directory may have been removed(see
          // SPARK-12316 for more details).
          delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.MINUTES)
        } else {
          logInfo(s"Scheduling token refresh from HDFS in $timeFromNowToRenewal millis.")
          delegationTokenRenewer.schedule(
            executorUpdaterRunnable, timeFromNowToRenewal, TimeUnit.MILLISECONDS)
        }
      }
    } catch {
      // Since the file may get deleted while we are reading it, catch the Exception and come
      // back in an hour to try again
      case NonFatal(e) =>
        logWarning("Error while trying to update credentials, will try again in 1 hour", e)
        delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS)
    }
  }

  /**
   * Reads a `Credentials` object from `tokenPath` on `remoteFs`. The stream is always closed.
   */
  private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = {
    val stream = remoteFs.open(tokenPath)
    try {
      val newCredentials = new Credentials()
      newCredentials.readTokenStorageStream(stream)
      newCredentials
    } finally {
      stream.close()
    }
  }

  /** Shuts down the scheduler that runs the refresh task. */
  def stop(): Unit = {
    delegationTokenRenewer.shutdown()
  }

}
Example 18
Source File: RateController.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkConf
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.{ThreadUtils, Utils}


  /**
   * Asynchronously asks `rateEstimator` for a new rate. When one is produced it is stored in
   * `rateLimit` and the latest rate is published.
   */
  private def computeAndPublish(
      time: Long,
      elems: Long,
      workDelay: Long,
      waitDelay: Long): Unit = {
    Future[Unit] {
      for (rate <- rateEstimator.compute(time, elems, workDelay, waitDelay)) {
        rateLimit.set(rate.toLong)
        publish(getLatestRate())
      }
    }
  }

  /** Returns the value currently stored in `rateLimit`. */
  def getLatestRate(): Long = {
    rateLimit.get()
  }

  /**
   * Triggers `computeAndPublish` for a completed batch. Nothing happens unless the batch has
   * a processing end time, a processing delay, a scheduling delay and input info for
   * `streamUID`.
   */
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    val info = batchCompleted.batchInfo
    val inputInfos = info.streamIdToInputInfo

    for {
      processingEnd <- info.processingEndTime
      workDelay <- info.processingDelay
      waitDelay <- info.schedulingDelay
      inputInfo <- inputInfos.get(streamUID)
    } computeAndPublish(processingEnd, inputInfo.numRecords, workDelay, waitDelay)
  }
}

object RateController {
  /**
   * Reads the `spark.streaming.backpressure.enabled` flag from `conf`; `false` is passed as
   * the default.
   */
  def isBackPressureEnabled(conf: SparkConf): Boolean = {
    val backPressureKey = "spark.streaming.backpressure.enabled"
    conf.getBoolean(backPressureKey, false)
  }
}
Example 19
Source File: LauncherBackend.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.launcher

import java.net.{InetAddress, Socket}

import org.apache.spark.SPARK_VERSION
import org.apache.spark.launcher.LauncherProtocol._
import org.apache.spark.util.{ThreadUtils, Utils}


  /**
   * Callback run from `BackendConnection.close()` after the close of the underlying
   * connection has been attempted. It is a no-op here.
   */
  protected def onDisconnected(): Unit = ()

  /**
   * Runs `onStopRequest()` on a new thread obtained from `LauncherBackend.threadFactory`.
   * The call is wrapped in `Utils.tryLogNonFatalError`.
   */
  private def fireStopRequest(): Unit = {
    val stopTask = new Runnable {
      override def run(): Unit = Utils.tryLogNonFatalError(onStopRequest())
    }
    LauncherBackend.threadFactory.newThread(stopTask).start()
  }

  /**
   * `LauncherConnection` over socket `s`. A `Stop` message triggers `fireStopRequest()`;
   * every other message type is rejected with an `IllegalArgumentException`.
   */
  private class BackendConnection(s: Socket) extends LauncherConnection(s) {

    override protected def handle(m: Message): Unit = {
      m match {
        case _: Stop => fireStopRequest()
        case unexpected =>
          val typeName = unexpected.getClass().getName()
          throw new IllegalArgumentException(s"Unexpected message type: $typeName")
      }
    }

    /** Closes the connection, then always calls `onDisconnected()` and clears the flag. */
    override def close(): Unit = {
      try super.close()
      finally {
        // Executed even when super.close() throws.
        onDisconnected()
        _isConnected = false
      }
    }

  }

}

private object LauncherBackend {

  /** Thread factory obtained from `ThreadUtils.namedThreadFactory("LauncherBackend")`. */
  val threadFactory: java.util.concurrent.ThreadFactory =
    ThreadUtils.namedThreadFactory("LauncherBackend")

}
Example 20
Source File: FutureActionSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import scala.concurrent.duration.Duration

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark.util.ThreadUtils


/** Checks the result and the number of job ids reported by asynchronous RDD actions. */
class FutureActionSuite
  extends SparkFunSuite
  with BeforeAndAfter
  with Matchers
  with LocalSparkContext {

  // Registers creation of a local SparkContext through the BeforeAndAfter hook.
  before {
    sc = new SparkContext("local", "FutureActionSuite")
  }

  test("simple async action") {
    val countJob = sc.parallelize(1 to 10, 2).countAsync()
    val count = ThreadUtils.awaitResult(countJob, Duration.Inf)
    count shouldBe 10
    countJob.jobIds.size shouldBe 1
  }

  test("complex async action") {
    val takeJob = sc.parallelize(1 to 15, 3).takeAsync(10)
    val taken = ThreadUtils.awaitResult(takeJob, Duration.Inf)
    taken shouldBe (1 to 10)
    takeJob.jobIds.size shouldBe 2
  }

}
Example 21
Source File: BlockTransferService.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network

import java.io.Closeable
import java.nio.ByteBuffer

import scala.concurrent.{Future, Promise}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.ThreadUtils

/**
 * Abstract block transfer service. Only the blocking upload helper is present in this
 * excerpt; `uploadBlock` is declared elsewhere.
 */
private[spark]
abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {

  /**
   * Starts an upload through `uploadBlock` and blocks the calling thread, without a timeout,
   * until the returned future completes.
   */
  def uploadBlockSync(
      hostname: String,
      port: Int,
      execId: String,
      blockId: BlockId,
      blockData: ManagedBuffer,
      level: StorageLevel,
      classTag: ClassTag[_]): Unit = {
    ThreadUtils.awaitResult(
      uploadBlock(hostname, port, execId, blockId, blockData, level, classTag),
      Duration.Inf)
  }
}
Example 22
Source File: BlockManagerSlaveEndpoint.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{MapOutputTracker, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}


/**
 * RPC endpoint that answers block-management requests using the given `BlockManager`.
 * Removal requests are executed on a dedicated thread pool through `doAsync`; the other
 * requests reply directly from the handler.
 */
private[storage]
class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  // Pool that backs every doAsync call; it is shut down in onStop().
  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        // The tracker may be null, in which case only the shuffle manager is updated.
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())
  }

  /**
   * Evaluates `body` on the async pool. On success the result is sent back through
   * `context.reply`; on failure the error is logged and reported with `context.sendFailure`.
   *
   * @param actionMessage description of the action, used in log messages
   * @param context       RPC call context that receives the reply or the failure
   * @param body          computation to run asynchronously
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    // `foreach` / `failed.foreach` replace the deprecated `onSuccess` / `onFailure` callbacks.
    future.foreach { response =>
      logDebug("Done " + actionMessage + ", response is " + response)
      context.reply(response)
      logDebug("Sent response: " + response + " to " + context.senderAddress)
    }
    future.failed.foreach { t =>
      logError("Error in " + actionMessage, t)
      context.sendFailure(t)
    }
  }

  /** Stops the async pool via `shutdownNow()`. */
  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 23
Source File: FutureActionSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import scala.concurrent.duration.Duration

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark.util.ThreadUtils


/** Checks the result and the number of job ids reported by asynchronous RDD actions. */
class FutureActionSuite
  extends SparkFunSuite
  with BeforeAndAfter
  with Matchers
  with LocalSparkContext {

  // Registers creation of a local SparkContext through the BeforeAndAfter hook.
  before {
    sc = new SparkContext("local", "FutureActionSuite")
  }

  test("simple async action") {
    val countJob = sc.parallelize(1 to 10, 2).countAsync()
    val count = ThreadUtils.awaitResult(countJob, Duration.Inf)
    count shouldBe 10
    countJob.jobIds.size shouldBe 1
  }

  test("complex async action") {
    val takeJob = sc.parallelize(1 to 15, 3).takeAsync(10)
    val taken = ThreadUtils.awaitResult(takeJob, Duration.Inf)
    taken shouldBe (1 to 10)
    takeJob.jobIds.size shouldBe 2
  }

}
Example 24
Source File: RPCContinuousShuffleWriter.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.continuous.shuffle

import scala.concurrent.Future
import scala.concurrent.duration.Duration

import org.apache.spark.Partitioner
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.util.ThreadUtils


/**
 * Shuffle writer that sends every row of an epoch to an RPC endpoint chosen by
 * `outputPartitioner`, then sends an epoch marker to all endpoints and waits for the replies.
 */
class RPCContinuousShuffleWriter(
    writerId: Int,
    outputPartitioner: Partitioner,
    endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter {

  // Construction fails unless the partitioner has exactly one partition.
  if (outputPartitioner.numPartitions != 1) {
    throw new IllegalArgumentException("multiple readers not yet supported")
  }

  // The partition count must also equal the number of endpoints.
  if (outputPartitioner.numPartitions != endpoints.length) {
    throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " +
      s"not match endpoint count ${endpoints.length}")
  }

  /**
   * Writes one epoch: each row is sent with `askSync` to the endpoint of its partition, then
   * a `ReceiverEpochMarker` is sent to every endpoint and all replies are awaited.
   */
  def write(epoch: Iterator[UnsafeRow]): Unit = {
    epoch.foreach { row =>
      val target = endpoints(outputPartitioner.getPartition(row))
      target.askSync[Unit](ReceiverRow(writerId, row))
    }

    val markerReplies = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq
    implicit val ec = ThreadUtils.sameThread
    val allReplies = Future.sequence(markerReplies)
    ThreadUtils.awaitResult(allReplies, Duration.Inf)
  }
}
Example 25
Source File: OapRpcManagerSlave.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.oap.rpc

import java.util.concurrent.TimeUnit

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.execution.datasources.oap.filecache.{CacheStats, FiberCacheManager}
import org.apache.spark.sql.internal.oap.OapConf
import org.apache.spark.sql.oap.adapter.RpcEndpointRefAdapter
import org.apache.spark.sql.oap.rpc.OapMessages._
import org.apache.spark.storage.BlockManager
import org.apache.spark.util.{ThreadUtils, Utils}


/**
 * Executor-side OAP RPC manager. On construction it registers an endpoint with the driver
 * and starts a scheduled task that sends heartbeat messages to `driverEndpoint`.
 */
private[spark] class OapRpcManagerSlave(
    rpcEnv: RpcEnv,
    val driverEndpoint: RpcEndpointRef,
    executorId: String,
    blockManager: BlockManager,
    fiberCacheManager: FiberCacheManager,
    conf: SparkConf) extends OapRpcManager {

  // Scheduler that periodically sends the heartbeat messages to the driver.
  private val oapHeartbeater =
    ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater")

  private val slaveEndpoint = rpcEnv.setupEndpoint(
    s"OapRpcManagerSlave_$executorId", new OapRpcManagerSlaveEndpoint(rpcEnv, fiberCacheManager))

  initialize()
  startOapHeartbeater()

  /** Factories for the heartbeat messages; each one is evaluated when a heartbeat is sent. */
  protected def heartbeatMessages: Array[() => Heartbeat] = {
    val fiberCacheStatus: () => Heartbeat = () => FiberCacheHeartbeat(
      executorId, blockManager.blockManagerId, fiberCacheManager.status())
    val fiberCacheMetrics: () => Heartbeat = () => FiberCacheMetricsHeartbeat(
      executorId, blockManager.blockManagerId,
      CacheStats.status(fiberCacheManager.cacheStats, conf))
    Array(fiberCacheStatus, fiberCacheMetrics)
  }

  // Registers this manager's endpoint with the driver and waits for the answer.
  private def initialize() = {
    val registration = RegisterOapRpcManager(executorId, slaveEndpoint)
    RpcEndpointRefAdapter.askSync[Boolean](driverEndpoint, registration)
  }

  override private[spark] def send(message: OapMessage): Unit = driverEndpoint.send(message)

  /**
   * Schedules the heartbeat task at a fixed rate. The interval comes from
   * `OapConf.OAP_HEARTBEAT_INTERVAL`; the first run is delayed by the interval plus a random
   * fraction of it.
   */
  private[sql] def startOapHeartbeater(): Unit = {
    val heartbeatTask = new Runnable {
      override def run(): Unit = Utils.logUncaughtExceptions {
        // OapRpcManagerSlave is created in SparkEnv. Before we start the heartbeat, we need make
        // sure the SparkEnv has been created and the block manager has been initialized. We check
        // blockManagerId as it will be set after initialization.
        if (blockManager.blockManagerId != null) {
          // All messages are built first and only then sent.
          val messages = heartbeatMessages.map(_.apply())
          messages.foreach(send)
        }
      }
    }

    val intervalMs = conf.getTimeAsMs(
      OapConf.OAP_HEARTBEAT_INTERVAL.key, OapConf.OAP_HEARTBEAT_INTERVAL.defaultValue.get)

    // Wait a random interval so the heartbeats don't end up in sync
    val initialDelay = intervalMs + (math.random * intervalMs).toInt

    oapHeartbeater.scheduleAtFixedRate(
      heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
  }

  /** Shuts down the heartbeat scheduler. */
  override private[spark] def stop(): Unit = oapHeartbeater.shutdown()
}

/**
 * Endpoint that receives `OapMessage`s. A `CacheDrop` releases the named index cache through
 * `fiberCacheManager`; every other message is ignored.
 */
private[spark] class OapRpcManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv, fiberCacheManager: FiberCacheManager)
  extends ThreadSafeRpcEndpoint with Logging {

  override def receive: PartialFunction[Any, Unit] = {
    case oapMessage: OapMessage => handleOapMessage(oapMessage)
    case _ => () // anything that is not an OapMessage is dropped
  }

  private def handleOapMessage(message: OapMessage): Unit = {
    message match {
      case CacheDrop(indexName) => fiberCacheManager.releaseIndexCache(indexName)
      case _ => () // other OapMessage kinds need no action here
    }
  }
}
Example 26
Source File: RateController.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkConf
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.{ThreadUtils, Utils}


  /**
   * Asynchronously asks `rateEstimator` for a new rate. When one is produced it is stored in
   * `rateLimit` and the latest rate is published.
   */
  private def computeAndPublish(
      time: Long,
      elems: Long,
      workDelay: Long,
      waitDelay: Long): Unit = {
    Future[Unit] {
      for (rate <- rateEstimator.compute(time, elems, workDelay, waitDelay)) {
        rateLimit.set(rate.toLong)
        publish(getLatestRate())
      }
    }
  }

  /** Returns the value currently stored in `rateLimit`. */
  def getLatestRate(): Long = {
    rateLimit.get()
  }

  /**
   * Triggers `computeAndPublish` for a completed batch. Nothing happens unless the batch has
   * a processing end time, a processing delay, a scheduling delay and input info for
   * `streamUID`.
   */
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    val info = batchCompleted.batchInfo
    val inputInfos = info.streamIdToInputInfo

    for {
      processingEnd <- info.processingEndTime
      workDelay <- info.processingDelay
      waitDelay <- info.schedulingDelay
      inputInfo <- inputInfos.get(streamUID)
    } computeAndPublish(processingEnd, inputInfo.numRecords, workDelay, waitDelay)
  }
}

object RateController {
  /**
   * Reads the `spark.streaming.backpressure.enabled` flag from `conf`; `false` is passed as
   * the default.
   */
  def isBackPressureEnabled(conf: SparkConf): Boolean = {
    val backPressureKey = "spark.streaming.backpressure.enabled"
    conf.getBoolean(backPressureKey, false)
  }
}
Example 27
Source File: BlockTransferService.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network

import java.io.Closeable
import java.nio.ByteBuffer

import scala.concurrent.{Future, Promise}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.ThreadUtils

/**
 * Abstract block transfer service. Only the blocking upload helper is present in this
 * excerpt; `uploadBlock` is declared elsewhere.
 */
private[spark]
abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {

  /**
   * Starts an upload through `uploadBlock` and blocks the calling thread, without a timeout,
   * until the returned future completes.
   */
  def uploadBlockSync(
      hostname: String,
      port: Int,
      execId: String,
      blockId: BlockId,
      blockData: ManagedBuffer,
      level: StorageLevel,
      classTag: ClassTag[_]): Unit = {
    ThreadUtils.awaitResult(
      uploadBlock(hostname, port, execId, blockId, blockData, level, classTag),
      Duration.Inf)
  }
}
Example 28
Source File: LauncherBackend.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.launcher

import java.net.{InetAddress, Socket}

import org.apache.spark.SPARK_VERSION
import org.apache.spark.launcher.LauncherProtocol._
import org.apache.spark.util.{ThreadUtils, Utils}


  /**
   * Callback run from `BackendConnection.close()` after the close of the underlying
   * connection has been attempted. It is a no-op here.
   */
  protected def onDisconnected(): Unit = ()

  /**
   * Runs `onStopRequest()` on a new thread obtained from `LauncherBackend.threadFactory`.
   * The call is wrapped in `Utils.tryLogNonFatalError`.
   */
  private def fireStopRequest(): Unit = {
    val stopTask = new Runnable {
      override def run(): Unit = Utils.tryLogNonFatalError(onStopRequest())
    }
    LauncherBackend.threadFactory.newThread(stopTask).start()
  }

  /**
   * `LauncherConnection` over socket `s`. A `Stop` message triggers `fireStopRequest()`;
   * every other message type is rejected with an `IllegalArgumentException`.
   */
  private class BackendConnection(s: Socket) extends LauncherConnection(s) {

    override protected def handle(m: Message): Unit = {
      m match {
        case _: Stop => fireStopRequest()
        case unexpected =>
          val typeName = unexpected.getClass().getName()
          throw new IllegalArgumentException(s"Unexpected message type: $typeName")
      }
    }

    /** Closes the connection, then always calls `onDisconnected()` and clears the flag. */
    override def close(): Unit = {
      try super.close()
      finally {
        // Executed even when super.close() throws.
        onDisconnected()
        _isConnected = false
      }
    }

  }

}

private object LauncherBackend {

  /** Thread factory obtained from `ThreadUtils.namedThreadFactory("LauncherBackend")`. */
  val threadFactory: java.util.concurrent.ThreadFactory =
    ThreadUtils.namedThreadFactory("LauncherBackend")

}
Example 29
Source File: BlockManagerSlaveEndpoint.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.{MapOutputTracker, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}


/**
 * RPC endpoint that answers block-management requests using the given `BlockManager`.
 * Removal requests are executed on a dedicated thread pool through `doAsync`; the other
 * requests reply directly from the handler.
 */
private[storage]
class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  // Pool that backs every doAsync call; it is shut down in onStop().
  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        // The tracker may be null, in which case only the shuffle manager is updated.
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())
  }

  /**
   * Evaluates `body` on the async pool. On success the result is sent back through
   * `context.reply`; on failure the error is logged and reported with `context.sendFailure`.
   *
   * @param actionMessage description of the action, used in log messages
   * @param context       RPC call context that receives the reply or the failure
   * @param body          computation to run asynchronously
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    // `foreach` / `failed.foreach` replace the deprecated `onSuccess` / `onFailure` callbacks.
    future.foreach { response =>
      logDebug("Done " + actionMessage + ", response is " + response)
      context.reply(response)
      logDebug("Sent response: " + response + " to " + context.senderAddress)
    }
    future.failed.foreach { t =>
      logError("Error in " + actionMessage, t)
      context.sendFailure(t)
    }
  }

  /** Stops the async pool via `shutdownNow()`. */
  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 30
Source File: FutureActionSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import scala.concurrent.duration.Duration

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark.util.ThreadUtils


/** Checks the result and the number of job ids reported by asynchronous RDD actions. */
class FutureActionSuite
  extends SparkFunSuite
  with BeforeAndAfter
  with Matchers
  with LocalSparkContext {

  // Registers creation of a local SparkContext through the BeforeAndAfter hook.
  before {
    sc = new SparkContext("local", "FutureActionSuite")
  }

  test("simple async action") {
    val countJob = sc.parallelize(1 to 10, 2).countAsync()
    val count = ThreadUtils.awaitResult(countJob, Duration.Inf)
    count shouldBe 10
    countJob.jobIds.size shouldBe 1
  }

  test("complex async action") {
    val takeJob = sc.parallelize(1 to 15, 3).takeAsync(10)
    val taken = ThreadUtils.awaitResult(takeJob, Duration.Inf)
    taken shouldBe (1 to 10)
    takeJob.jobIds.size shouldBe 2
  }

}
Example 31
Source File: RateController.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkConf
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.{ThreadUtils, Utils}


  /**
   * Asynchronously asks `rateEstimator` for a new rate. When one is produced it is stored in
   * `rateLimit` and the latest rate is published.
   */
  private def computeAndPublish(
      time: Long,
      elems: Long,
      workDelay: Long,
      waitDelay: Long): Unit = {
    Future[Unit] {
      for (rate <- rateEstimator.compute(time, elems, workDelay, waitDelay)) {
        rateLimit.set(rate.toLong)
        publish(getLatestRate())
      }
    }
  }

  /** Returns the value currently stored in `rateLimit`. */
  def getLatestRate(): Long = {
    rateLimit.get()
  }

  /**
   * Triggers `computeAndPublish` for a completed batch. Nothing happens unless the batch has
   * a processing end time, a processing delay, a scheduling delay and input info for
   * `streamUID`.
   */
  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    val info = batchCompleted.batchInfo
    val inputInfos = info.streamIdToInputInfo

    for {
      processingEnd <- info.processingEndTime
      workDelay <- info.processingDelay
      waitDelay <- info.schedulingDelay
      inputInfo <- inputInfos.get(streamUID)
    } computeAndPublish(processingEnd, inputInfo.numRecords, workDelay, waitDelay)
  }
}

object RateController {
  /**
   * Reads the `spark.streaming.backpressure.enabled` flag from `conf`; `false` is passed as
   * the default.
   */
  def isBackPressureEnabled(conf: SparkConf): Boolean = {
    val backPressureKey = "spark.streaming.backpressure.enabled"
    conf.getBoolean(backPressureKey, false)
  }
}
Example 32
Source File: BlockTransferService.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network

import java.io.Closeable
import java.nio.ByteBuffer

import scala.concurrent.{Future, Promise}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.ThreadUtils

/**
 * Abstract block transfer service. Only the blocking upload helper is present in this
 * excerpt; `uploadBlock` is declared elsewhere.
 */
private[spark]
abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {

  /**
   * Starts an upload through `uploadBlock` and blocks the calling thread, without a timeout,
   * until the returned future completes.
   */
  def uploadBlockSync(
      hostname: String,
      port: Int,
      execId: String,
      blockId: BlockId,
      blockData: ManagedBuffer,
      level: StorageLevel,
      classTag: ClassTag[_]): Unit = {
    ThreadUtils.awaitResult(
      uploadBlock(hostname, port, execId, blockId, blockData, level, classTag),
      Duration.Inf)
  }
}
Example 33
Source File: LauncherBackend.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.launcher

import java.net.{InetAddress, Socket}

import org.apache.spark.SPARK_VERSION
import org.apache.spark.launcher.LauncherProtocol._
import org.apache.spark.util.{ThreadUtils, Utils}


  /**
   * Callback run from `BackendConnection.close()` after the close of the underlying
   * connection has been attempted. It is a no-op here.
   */
  protected def onDisconnected(): Unit = ()

  /**
   * Runs `onStopRequest()` on a new thread obtained from `LauncherBackend.threadFactory`.
   * The call is wrapped in `Utils.tryLogNonFatalError`.
   */
  private def fireStopRequest(): Unit = {
    val stopTask = new Runnable {
      override def run(): Unit = Utils.tryLogNonFatalError(onStopRequest())
    }
    LauncherBackend.threadFactory.newThread(stopTask).start()
  }

  /**
   * `LauncherConnection` over socket `s`. A `Stop` message triggers `fireStopRequest()`;
   * every other message type is rejected with an `IllegalArgumentException`.
   */
  private class BackendConnection(s: Socket) extends LauncherConnection(s) {

    override protected def handle(m: Message): Unit = {
      m match {
        case _: Stop => fireStopRequest()
        case unexpected =>
          val typeName = unexpected.getClass().getName()
          throw new IllegalArgumentException(s"Unexpected message type: $typeName")
      }
    }

    /** Closes the connection, then always calls `onDisconnected()` and clears the flag. */
    override def close(): Unit = {
      try super.close()
      finally {
        // Executed even when super.close() throws.
        onDisconnected()
        _isConnected = false
      }
    }

  }

}

private object LauncherBackend {

  /** Thread factory obtained from `ThreadUtils.namedThreadFactory("LauncherBackend")`. */
  val threadFactory: java.util.concurrent.ThreadFactory =
    ThreadUtils.namedThreadFactory("LauncherBackend")

}
Example 34
Source File: BlockManagerSlaveEndpoint.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}
import org.apache.spark.{MapOutputTracker, SparkEnv}


/**
 * RPC endpoint that executes block-manager commands (block/RDD/shuffle/broadcast removal,
 * status queries, thread dumps) against the given `blockManager`.
 *
 * Removal commands are run on a dedicated thread pool and answered when they finish;
 * status queries and thread dumps are answered inline.
 *
 * @param rpcEnv the RPC environment this endpoint belongs to
 * @param blockManager the block manager the commands operate on
 * @param mapOutputTracker tracker to unregister shuffles from; may be null, in which case
 *                         only the shuffle manager is told about the removal
 */
private[storage]
class BlockManagerSlaveEndpoint(
    override val rpcEnv: RpcEnv,
    blockManager: BlockManager,
    mapOutputTracker: MapOutputTracker)
  extends ThreadSafeRpcEndpoint with Logging {

  // Captured once at construction and later used to look up the SparkEnv for shuffle removal.
  private val user = Utils.getCurrentUserName

  // Pool (and the execution context built on it) that runs the asynchronous removals.
  private val asyncThreadPool =
    ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)

  // Operations that involve removing blocks may be slow and should be done asynchronously
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RemoveBlock(blockId) =>
      doAsync[Boolean](s"removing block $blockId", context) {
        blockManager.removeBlock(blockId)
        true
      }

    case RemoveRdd(rddId) =>
      doAsync[Int](s"removing RDD $rddId", context) {
        blockManager.removeRdd(rddId)
      }

    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean](s"removing shuffle $shuffleId", context) {
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get(user).shuffleManager.unregisterShuffle(shuffleId)
      }

    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int](s"removing broadcast $broadcastId", context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())
  }

  /**
   * Evaluates `body` on the async pool and answers the RPC when it completes: the result is
   * sent with `context.reply`, a failure is logged and sent with `context.sendFailure`.
   *
   * @param actionMessage description of the action, used only in log messages
   * @param context the RPC call to answer
   * @param body the (possibly slow) work to run
   */
  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
    val future = Future {
      logDebug(actionMessage)
      body
    }
    // `foreach` / `failed.foreach` replace `onSuccess` / `onFailure`, which are deprecated
    // since Scala 2.12; the callbacks run under the same conditions.
    future.foreach { response =>
      logDebug(s"Done $actionMessage, response is $response")
      context.reply(response)
      logDebug(s"Sent response: $response to ${context.senderAddress}")
    }
    future.failed.foreach { t =>
      logError(s"Error in $actionMessage", t)
      context.sendFailure(t)
    }
  }

  /** Shuts down the async pool when the endpoint stops. */
  override def onStop(): Unit = {
    asyncThreadPool.shutdownNow()
  }
}
Example 35
Source File: BroadcastHashJoin.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.rdd.RDD
import org.apache.spark.util.ThreadUtils

import scala.concurrent._
import scala.concurrent.duration._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}


/**
 * Hash join in which the build side is collected, turned into a [[HashedRelation]] and
 * broadcast, and each partition of the streamed side is then joined against the
 * broadcast relation.
 */
@DeveloperApi
case class BroadcastHashJoin(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    buildSide: BuildSide,
    left: SparkPlan,
    right: SparkPlan)
  extends BinaryNode with HashJoin {

  /**
   * Maximum time to wait for the broadcast to become available. A negative configured
   * `broadcastTimeout` means wait forever; otherwise the value is taken as seconds.
   */
  val timeout: Duration = {
    val timeoutValue = sqlContext.conf.broadcastTimeout
    if (timeoutValue < 0) {
      Duration.Inf
    } else {
      timeoutValue.seconds
    }
  }

  override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning

  override def requiredChildDistribution: Seq[Distribution] =
    UnspecifiedDistribution :: UnspecifiedDistribution :: Nil

  // Built lazily on the dedicated broadcast execution context. `Future { ... }` is used
  // instead of the lowercase `future { ... }` helper, which is deprecated since Scala 2.11
  // and simply delegates to `Future.apply`.
  @transient
  lazy val broadcastFuture = Future {
    // Note that we use .execute().collect() because we don't want to convert data to Scala types
    val input: Array[Row] = buildPlan.execute().map(_.copy()).collect()
    val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
    sparkContext.broadcast(hashed)
  }(BroadcastHashJoin.broadcastHashJoinExecutionContext)

  /**
   * Waits up to `timeout` for the broadcast relation, then hash-joins every partition of
   * the streamed plan against its value.
   */
  protected override def doExecute(): RDD[Row] = {
    val broadcastRelation = Await.result(broadcastFuture, timeout)

    streamedPlan.execute().mapPartitions { streamedIter =>
      hashJoin(streamedIter, broadcastRelation.value)
    }
  }
}

object BroadcastHashJoin {
  /** Execution context on which the build side is collected and broadcast. */
  private[sql] val broadcastHashJoinExecutionContext = {
    // NOTE(review): 128 is presumably the maximum number of pool threads — confirm against
    // ThreadUtils.newDaemonCachedThreadPool.
    val broadcastPool = ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)
    ExecutionContext.fromExecutorService(broadcastPool)
  }
}
Example 36
Source File: MesosExternalShuffleService.scala    From drizzle-spark   with Apache License 2.0 4 votes vote down vote up
package org.apache.spark.deploy.mesos

import java.nio.ByteBuffer
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}

import scala.collection.JavaConverters._

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.ExternalShuffleService
import org.apache.spark.internal.Logging
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage
import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat}
import org.apache.spark.network.util.TransportConf
import org.apache.spark.util.ThreadUtils


/**
 * [[ExternalShuffleService]] variant whose block handler is a
 * `MesosExternalShuffleBlockHandler`.
 */
private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManager: SecurityManager)
  extends ExternalShuffleService(conf, securityManager) {

  /**
   * Builds the Mesos block handler from the given transport configuration and the cleaner
   * interval read from this service's SparkConf.
   */
  protected override def newShuffleBlockHandler(
      conf: TransportConf): ExternalShuffleBlockHandler = {
    // `this.conf` is the SparkConf constructor argument; the bare `conf` here is the
    // TransportConf parameter that shadows it.
    val cleanerIntervalSeconds =
      this.conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s")
    new MesosExternalShuffleBlockHandler(conf, cleanerIntervalSeconds)
  }
}

private[spark] object MesosExternalShuffleService extends Logging {

  /**
   * Entry point: delegates to `ExternalShuffleService.main`, passing a factory that
   * creates a [[MesosExternalShuffleService]].
   */
  def main(args: Array[String]): Unit = {
    val newMesosService = (conf: SparkConf, sm: SecurityManager) =>
      new MesosExternalShuffleService(conf, sm)
    ExternalShuffleService.main(args, newMesosService)
  }
}