java.io.DataOutputStream Scala Examples

The following examples show how to use java.io.DataOutputStream in Scala. Each example is taken from an open-source project; the project name and source file are noted in the header above each example.
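Before the project examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of the pattern most of them rely on: wrap an OutputStream in a DataOutputStream, write primitives and length-prefixed byte arrays, and read them back in the same order with a DataInputStream. Names such as DataOutputStreamRoundTrip are illustrative only.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets.UTF_8

object DataOutputStreamRoundTrip extends App {
  // Write primitives and a length-prefixed byte array into an in-memory buffer.
  val bos = new ByteArrayOutputStream()
  val out = new DataOutputStream(bos)
  out.writeInt(42)
  out.writeLong(System.currentTimeMillis())
  val payload = "hello".getBytes(UTF_8)
  out.writeInt(payload.length)
  out.write(payload)
  out.flush()

  // Read the same values back in the same order with a DataInputStream.
  val in = new DataInputStream(new ByteArrayInputStream(bos.toByteArray))
  val answer = in.readInt()
  val timestamp = in.readLong()
  val bytes = new Array[Byte](in.readInt())
  in.readFully(bytes)
  println(s"$answer, $timestamp, ${new String(bytes, UTF_8)}")
}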
Example 1
Source File: RBackendAuthHandler.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.api.r

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets.UTF_8

import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler}

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private class RBackendAuthHandler(secret: String)
  extends SimpleChannelInboundHandler[Array[Byte]] with Logging {

  override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
    // The R code adds a null terminator to serialized strings, so ignore it here.
    val clientSecret = new String(msg, 0, msg.length - 1, UTF_8)
    try {
      require(secret == clientSecret, "Auth secret mismatch.")
      ctx.pipeline().remove(this)
      writeReply("ok", ctx.channel())
    } catch {
      case e: Exception =>
        logInfo("Authentication failure.", e)
        writeReply("err", ctx.channel())
        ctx.close()
    }
  }

  private def writeReply(reply: String, chan: Channel): Unit = {
    val out = new ByteArrayOutputStream()
    SerDe.writeString(new DataOutputStream(out), reply)
    chan.writeAndFlush(out.toByteArray())
  }

} 
Example 2
Source File: MatrixWriteBlockMatrix.scala    From hail   with MIT License
package is.hail.expr.ir.functions

import java.io.DataOutputStream

import is.hail.HailContext
import is.hail.expr.ir.{ExecuteContext, MatrixValue}
import is.hail.types.{MatrixType, RPrimitive, RTable, TypeWithRequiredness}
import is.hail.types.virtual.{TVoid, Type}
import is.hail.linalg.{BlockMatrix, BlockMatrixMetadata, GridPartitioner, WriteBlocksRDD}
import is.hail.utils._
import org.json4s.jackson

case class MatrixWriteBlockMatrix(path: String,
  overwrite: Boolean,
  entryField: String,
  blockSize: Int) extends MatrixToValueFunction {
  def typ(childType: MatrixType): Type = TVoid

  def unionRequiredness(childType: RTable, resultType: TypeWithRequiredness): Unit = ()

  def execute(ctx: ExecuteContext, mv: MatrixValue): Any = {
    val rvd = mv.rvd

    // FIXME
    val partitionCounts = rvd.countPerPartition()

    val fs = ctx.fs

    val partStarts = partitionCounts.scanLeft(0L)(_ + _)
    assert(partStarts.length == rvd.getNumPartitions + 1)

    val nRows = partStarts.last
    val localNCols = mv.nCols

    if (overwrite)
      fs.delete(path, recursive = true)
    else if (fs.exists(path))
      fatal(s"file already exists: $path")

    fs.mkDir(path)

    // write blocks
    fs.mkDir(path + "/parts")
    val gp = GridPartitioner(blockSize, nRows, localNCols)
    val blockPartFiles =
      new WriteBlocksRDD(fs.broadcast, ctx.localTmpdir, path, rvd, partStarts, entryField, gp)
        .collect()

    val blockCount = blockPartFiles.length
    val partFiles = new Array[String](blockCount)
    blockPartFiles.foreach { case (i, f) => partFiles(i) = f }

    // write metadata
    using(new DataOutputStream(fs.create(path + BlockMatrix.metadataRelativePath))) { os =>
      implicit val formats = defaultJSONFormats
      jackson.Serialization.write(
        BlockMatrixMetadata(blockSize, nRows, localNCols, gp.partitionIndexToBlockIndex, partFiles),
        os)
    }

    assert(blockCount == gp.numPartitions)
    info(s"Wrote all $blockCount blocks of $nRows x $localNCols matrix with block size $blockSize.")

    using(fs.create(path + "/_SUCCESS"))(out => ())

    null
  }
} 
Example 3
Source File: ByteArrayOutputFormat.scala    From hail   with MIT License
package is.hail.io.hadoop

import java.io.DataOutputStream

import org.apache.hadoop.fs._
import org.apache.hadoop.io._
import org.apache.hadoop.mapred._
import org.apache.hadoop.util.Progressable

class ByteArrayOutputFormat extends FileOutputFormat[NullWritable, BytesOnlyWritable] {

  class ByteArrayRecordWriter(out: DataOutputStream) extends RecordWriter[NullWritable, BytesOnlyWritable] {

    def write(key: NullWritable, value: BytesOnlyWritable) {
      if (value != null)
        value.write(out)
    }

    def close(reporter: Reporter) {
      out.close()
    }
  }

  override def getRecordWriter(ignored: FileSystem, job: JobConf,
    name: String, progress: Progressable): RecordWriter[NullWritable, BytesOnlyWritable] = {
    val file: Path = FileOutputFormat.getTaskOutputPath(job, name)
    val fs: FileSystem = file.getFileSystem(job)
    val fileOut: FSDataOutputStream = fs.create(file, progress)
    new ByteArrayRecordWriter(fileOut)
  }
} 
Example 4
Source File: PythonGatewayServer.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket

import py4j.GatewayServer

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[spark] object PythonGatewayServer extends Logging {
  initializeLogIfNecessary(true)

  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // Start a GatewayServer on an ephemeral port
    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the bound port back to the caller via the caller-specified callback port
    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
    val callbackSocket = new Socket(callbackHost, callbackPort)
    val dos = new DataOutputStream(callbackSocket.getOutputStream)
    dos.writeInt(boundPort)
    dos.close()
    callbackSocket.close()

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
} 
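The writer side above sends a single 4-byte int over the callback socket. In Spark itself the receiving end is Python code in java_gateway.py; purely for illustration, a hypothetical Scala counterpart that accepts the connection and reads the port with a DataInputStream could look like this (CallbackPortReader is not part of Spark):

import java.io.DataInputStream
import java.net.ServerSocket

object CallbackPortReader extends App {
  // Listen on an ephemeral port; a real caller would pass this port to the JVM side.
  val server = new ServerSocket(0)
  println(s"Waiting for callback on port ${server.getLocalPort}")
  val socket = server.accept()
  val in = new DataInputStream(socket.getInputStream)
  // readInt() matches the writeInt(boundPort) on the writer side.
  val gatewayPort = in.readInt()
  println(s"Gateway server is listening on port $gatewayPort")
  in.close()
  socket.close()
  server.close()
}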
Example 5
Source File: PythonRDDSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkFunSuite

class PythonRDDSuite extends SparkFunSuite {

  test("Writing large strings to the worker") {
    val input: List[String] = List("a"*100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }

  test("Handle nulls gracefully") {
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    // Should not have NPE when write an Iterator with null in it
    // The correctness will be tested in Python
    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
    PythonRDD.writeIteratorToStream(Iterator("a".getBytes(StandardCharsets.UTF_8), null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes(StandardCharsets.UTF_8)), buffer)
    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
    PythonRDD.writeIteratorToStream(Iterator(
      (null, null),
      ("a".getBytes(StandardCharsets.UTF_8), null),
      (null, "b".getBytes(StandardCharsets.UTF_8))), buffer)
  }
} 
Example 6
Source File: MasterWebUISuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 7
Source File: PythonGatewayServer.scala    From iolap   with Apache License 2.0
package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket

import py4j.GatewayServer

import org.apache.spark.Logging
import org.apache.spark.util.Utils


private[spark] object PythonGatewayServer extends Logging {
  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // Start a GatewayServer on an ephemeral port
    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the bound port back to the caller via the caller-specified callback port
    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
    val callbackSocket = new Socket(callbackHost, callbackPort)
    val dos = new DataOutputStream(callbackSocket.getOutputStream)
    dos.writeInt(boundPort)
    dos.close()
    callbackSocket.close()

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
} 
Example 8
Source File: PythonRDDSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}

import org.apache.spark.SparkFunSuite

class PythonRDDSuite extends SparkFunSuite {

  test("Writing large strings to the worker") {
    val input: List[String] = List("a"*100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }

  test("Handle nulls gracefully") {
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    // Should not have NPE when write an Iterator with null in it
    // The correctness will be tested in Python
    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
    PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer)
    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
    PythonRDD.writeIteratorToStream(
      Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer)
  }
} 
Example 9
Source File: SeqSerde.scala    From affinity   with Apache License 2.0
package io.amient.affinity.core.serde.collection

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import akka.actor.ExtendedActorSystem
import com.typesafe.config.Config
import io.amient.affinity.core.serde.{AbstractWrapSerde, Serde, Serdes}

class SeqSerde(serdes: Serdes) extends AbstractWrapSerde(serdes) with Serde[Seq[Any]] {

  def this(system: ExtendedActorSystem) = this(Serde.tools(system))
  def this(config: Config) = this(Serde.tools(config))

  override def identifier: Int = 141

  override def close(): Unit = ()

  override protected def fromBytes(bytes: Array[Byte]): Seq[Any] = {
    val di = new DataInputStream(new ByteArrayInputStream(bytes))
    val numItems = di.readInt()
    val result = ((1 to numItems) map { _ =>
      val len = di.readInt()
      val item = new Array[Byte](len)
      di.read(item)
      fromBinaryWrapped(item)
    }).toList
    di.close()
    result
  }

  override def toBytes(seq: Seq[Any]): Array[Byte] = {
    val os = new ByteArrayOutputStream()
    val d = new DataOutputStream(os)
    d.writeInt(seq.size)
    for (a: Any <- seq) a match {
      case ref: AnyRef =>
        val item = toBinaryWrapped(ref)
        d.writeInt(item.length)
        d.write(item)
    }
    os.close
    os.toByteArray
  }
} 
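The framing used here is an item count followed by a length prefix per item. A standalone sketch of the same layout, using UTF-8 strings in place of the wrapped Akka serializers (toBinaryWrapped/fromBinaryWrapped belong to the surrounding Affinity code and are not reproduced); SeqFraming is an illustrative name only:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets.UTF_8

object SeqFraming {
  def toBytes(seq: Seq[String]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    out.writeInt(seq.size)              // item count
    seq.foreach { s =>
      val item = s.getBytes(UTF_8)
      out.writeInt(item.length)         // per-item length prefix
      out.write(item)
    }
    out.flush()
    bos.toByteArray
  }

  def fromBytes(bytes: Array[Byte]): Seq[String] = {
    val in = new DataInputStream(new ByteArrayInputStream(bytes))
    val n = in.readInt()
    (1 to n).map { _ =>
      val item = new Array[Byte](in.readInt())
      in.readFully(item)                // readFully avoids short reads
      new String(item, UTF_8)
    }
  }
}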
Example 10
Source File: SetSerde.scala    From affinity   with Apache License 2.0
package io.amient.affinity.core.serde.collection

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import akka.actor.ExtendedActorSystem
import com.typesafe.config.Config
import io.amient.affinity.core.serde.{AbstractWrapSerde, Serde, Serdes}

class SetSerde(serdes: Serdes) extends AbstractWrapSerde(serdes) with Serde[Set[Any]] {

  def this(system: ExtendedActorSystem) = this(Serde.tools(system))
  def this(config: Config) = this(Serde.tools(config))

  override def identifier: Int = 142

  override protected def fromBytes(bytes: Array[Byte]): Set[Any] = {
    val di = new DataInputStream(new ByteArrayInputStream(bytes))
    val numItems = di.readInt()
    val result = ((1 to numItems) map { _ =>
      val len = di.readInt()
      val item = new Array[Byte](len)
      di.read(item)
      fromBinaryWrapped(item)
    }).toSet
    di.close()
    result
  }

  override def toBytes(set: Set[Any]): Array[Byte] = {
    val os = new ByteArrayOutputStream()
    val d = new DataOutputStream(os)
    d.writeInt(set.size)
    for (a: Any <- set) a match {
      case ref: AnyRef =>
        val item = toBinaryWrapped(ref)
        d.writeInt(item.length)
        d.write(item)
    }
    os.close
    os.toByteArray
  }

  override def close() = ()
} 
Example 11
Source File: PythonGatewayServer.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket

import py4j.GatewayServer

import org.apache.spark.Logging
import org.apache.spark.util.Utils


private[spark] object PythonGatewayServer extends Logging {
  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // Start a GatewayServer on an ephemeral port
    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the bound port back to the caller via the caller-specified callback port
    // Note: the difference between System.getenv() and System.getProperties():
    // System.getenv() returns OS environment variables, which can be set e.g. in the ".bashrc"
    // file of the current user's home directory; System.getProperties() returns JVM properties,
    // which are passed on the command line via the "-D" option.
    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
    val callbackSocket = new Socket(callbackHost, callbackPort)
    val dos = new DataOutputStream(callbackSocket.getOutputStream)
    dos.writeInt(boundPort)
    dos.close()
    callbackSocket.close()

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
} 
Example 12
Source File: PythonRDDSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}

import org.apache.spark.SparkFunSuite

class PythonRDDSuite extends SparkFunSuite {
  // Write large strings to the worker
  test("Writing large strings to the worker") {
    val input: List[String] = List("a"*100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }
  // Handle nulls gracefully
  test("Handle nulls gracefully") {
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    // Should not have NPE when write an Iterator with null in it
    // The correctness will be tested in Python
    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
    PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer)
    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
    PythonRDD.writeIteratorToStream(
      Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer)
  }
} 
Example 13
Source File: TestingTypedCount.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.hive.execution

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.hive.execution.TestingTypedCount.State
import org.apache.spark.sql.types._

@ExpressionDescription(
  usage = "_FUNC_(expr) - A testing aggregate function resembles COUNT " +
          "but implements ObjectAggregateFunction.")
case class TestingTypedCount(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[TestingTypedCount.State] {

  def this(child: Expression) = this(child, 0, 0)

  override def children: Seq[Expression] = child :: Nil

  override def dataType: DataType = LongType

  override def nullable: Boolean = false

  override def createAggregationBuffer(): State = TestingTypedCount.State(0L)

  override def update(buffer: State, input: InternalRow): State = {
    if (child.eval(input) != null) {
      buffer.count += 1
    }
    buffer
  }

  override def merge(buffer: State, input: State): State = {
    buffer.count += input.count
    buffer
  }

  override def eval(buffer: State): Any = buffer.count

  override def serialize(buffer: State): Array[Byte] = {
    val byteStream = new ByteArrayOutputStream()
    val dataStream = new DataOutputStream(byteStream)
    dataStream.writeLong(buffer.count)
    byteStream.toByteArray
  }

  override def deserialize(storageFormat: Array[Byte]): State = {
    val byteStream = new ByteArrayInputStream(storageFormat)
    val dataStream = new DataInputStream(byteStream)
    TestingTypedCount.State(dataStream.readLong())
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override val prettyName: String = "typed_count"
}

object TestingTypedCount {
  case class State(var count: Long)
} 
Example 14
Source File: SocketAuthHelper.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.security

import java.io.{DataInputStream, DataOutputStream, InputStream}
import java.net.Socket
import java.nio.charset.StandardCharsets.UTF_8

import org.apache.spark.SparkConf
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.Utils


private[spark] class SocketAuthHelper(conf: SparkConf) {

  // Secret shared with the remote side of the connection.
  val secret = Utils.createSecret(conf)

  def authToServer(s: Socket): Unit = {
    writeUtf8(secret, s)

    val reply = readUtf8(s)
    if (reply != "ok") {
      JavaUtils.closeQuietly(s)
      throw new IllegalArgumentException("Authentication failed.")
    }
  }

  protected def readUtf8(s: Socket): String = {
    val din = new DataInputStream(s.getInputStream())
    val len = din.readInt()
    val bytes = new Array[Byte](len)
    din.readFully(bytes)
    new String(bytes, UTF_8)
  }

  protected def writeUtf8(str: String, s: Socket): Unit = {
    val bytes = str.getBytes(UTF_8)
    val dout = new DataOutputStream(s.getOutputStream())
    dout.writeInt(bytes.length)
    dout.write(bytes, 0, bytes.length)
    dout.flush()
  }

} 
Example 15
Source File: PythonGatewayServer.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.api.python

import java.io.{DataOutputStream, File, FileOutputStream}
import java.net.InetAddress
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Files

import py4j.GatewayServer

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[spark] object PythonGatewayServer extends Logging {
  initializeLogIfNecessary(true)

  def main(args: Array[String]): Unit = {
    val secret = Utils.createSecret(new SparkConf())

    // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
    // with the same secret, in case the app needs callbacks from the JVM to the underlying
    // python processes.
    val localhost = InetAddress.getLoopbackAddress()
    val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
      .authToken(secret)
      .javaPort(0)
      .javaAddress(localhost)
      .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
      .build()

    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the connection information back to the python process by writing the
    // information in the requested file. This needs to match the read side in java_gateway.py.
    val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH"))
    val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
      "connection", ".info").toFile()

    val dos = new DataOutputStream(new FileOutputStream(tmpPath))
    dos.writeInt(boundPort)

    val secretBytes = secret.getBytes(UTF_8)
    dos.writeInt(secretBytes.length)
    dos.write(secretBytes, 0, secretBytes.length)
    dos.close()

    if (!tmpPath.renameTo(connectionInfoPath)) {
      logError(s"Unable to write connection information to $connectionInfoPath.")
      System.exit(1)
    }

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
} 
Example 16
Source File: ByteStringSerialization.scala    From scala-commons   with MIT License
package com.avsystem.commons
package redis.util

import java.io.{DataInputStream, DataOutputStream}

import akka.util._
import com.avsystem.commons.serialization.{GenCodec, StreamInput, StreamOutput}


object ByteStringSerialization {
  def write[T: GenCodec](value: T): ByteString = {
    val builder = new ByteStringBuilder
    GenCodec.write(new StreamOutput(new DataOutputStream(builder.asOutputStream)), value)
    builder.result()
  }

  def read[T: GenCodec](bytes: ByteString): T =
    GenCodec.read[T](new StreamInput(new DataInputStream(bytes.iterator.asInputStream)))
} 
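ByteStringBuilder.asOutputStream and ByteString.iterator.asInputStream (both used above) let the standard data streams write into and read from Akka ByteStrings. A sketch using them with plain DataOutputStream/DataInputStream and leaving out the GenCodec layer (ByteStringRoundTrip is an illustrative name):

import java.io.{DataInputStream, DataOutputStream}
import akka.util.{ByteString, ByteStringBuilder}

object ByteStringRoundTrip extends App {
  // Build a ByteString by writing through a DataOutputStream.
  val builder = new ByteStringBuilder
  val out = new DataOutputStream(builder.asOutputStream)
  out.writeInt(7)
  out.writeUTF("payload")
  out.flush()
  val bytes: ByteString = builder.result()

  // Read it back through a DataInputStream over the ByteString's iterator.
  val in = new DataInputStream(bytes.iterator.asInputStream)
  println((in.readInt(), in.readUTF()))
}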
Example 17
Source File: RAuthHelper.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.api.r

import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket

import org.apache.spark.SparkConf
import org.apache.spark.security.SocketAuthHelper

private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) {

  override protected def readUtf8(s: Socket): String = {
    SerDe.readString(new DataInputStream(s.getInputStream()))
  }

  override protected def writeUtf8(str: String, s: Socket): Unit = {
    val out = s.getOutputStream()
    SerDe.writeString(new DataOutputStream(out), str)
    out.flush()
  }

} 
Example 18
Source File: PythonRDDSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkFunSuite

class PythonRDDSuite extends SparkFunSuite {

  test("Writing large strings to the worker") {
    val input: List[String] = List("a"*100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }

  test("Handle nulls gracefully") {
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    // Should not have NPE when write an Iterator with null in it
    // The correctness will be tested in Python
    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
    PythonRDD.writeIteratorToStream(Iterator("a".getBytes(StandardCharsets.UTF_8), null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes(StandardCharsets.UTF_8)), buffer)
    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
    PythonRDD.writeIteratorToStream(Iterator(
      (null, null),
      ("a".getBytes(StandardCharsets.UTF_8), null),
      (null, "b".getBytes(StandardCharsets.UTF_8))), buffer)
  }
} 
Example 19
Source File: MasterWebUISuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 20
Source File: TaskDescriptionSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.scheduler

import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException}
import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.mutable.HashMap

import org.apache.spark.SparkFunSuite

class TaskDescriptionSuite extends SparkFunSuite {
  test("encoding and then decoding a TaskDescription results in the same TaskDescription") {
    val originalFiles = new HashMap[String, Long]()
    originalFiles.put("fileUrl1", 1824)
    originalFiles.put("fileUrl2", 2)

    val originalJars = new HashMap[String, Long]()
    originalJars.put("jar1", 3)

    val originalProperties = new Properties()
    originalProperties.put("property1", "18")
    originalProperties.put("property2", "test value")
    // SPARK-19796 -- large property values (like a large job description for a long sql query)
    // can cause problems for DataOutputStream, make sure we handle correctly
    val sb = new StringBuilder()
    (0 to 10000).foreach(_ => sb.append("1234567890"))
    val largeString = sb.toString()
    originalProperties.put("property3", largeString)
    // make sure we've got a good test case
    intercept[UTFDataFormatException] {
      val out = new DataOutputStream(new ByteArrayOutputStream())
      try {
        out.writeUTF(largeString)
      } finally {
        out.close()
      }
    }

    // Create a dummy byte buffer for the task.
    val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))

    val originalTaskDescription = new TaskDescription(
      taskId = 1520589,
      attemptNumber = 2,
      executorId = "testExecutor",
      name = "task for test",
      index = 19,
      originalFiles,
      originalJars,
      originalProperties,
      taskBuffer
    )

    val serializedTaskDescription = TaskDescription.encode(originalTaskDescription)
    val decodedTaskDescription = TaskDescription.decode(serializedTaskDescription)

    // Make sure that all of the fields in the decoded task description match the original.
    assert(decodedTaskDescription.taskId === originalTaskDescription.taskId)
    assert(decodedTaskDescription.attemptNumber === originalTaskDescription.attemptNumber)
    assert(decodedTaskDescription.executorId === originalTaskDescription.executorId)
    assert(decodedTaskDescription.name === originalTaskDescription.name)
    assert(decodedTaskDescription.index === originalTaskDescription.index)
    assert(decodedTaskDescription.addedFiles.equals(originalFiles))
    assert(decodedTaskDescription.addedJars.equals(originalJars))
    assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties))
    assert(decodedTaskDescription.serializedTask.equals(taskBuffer))
  }
} 
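As the intercept in this test shows, DataOutputStream.writeUTF stores the encoded length in an unsigned 16-bit field, so strings whose modified-UTF-8 encoding exceeds 65535 bytes throw UTFDataFormatException. A sketch of that limit and a common workaround, an explicit int length followed by raw UTF-8 bytes (this is only a sketch, not necessarily how TaskDescription.encode handles it):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, UTFDataFormatException}
import java.nio.charset.StandardCharsets.UTF_8

object LargeStringFraming extends App {
  val largeString = "1234567890" * 10001   // well over 64 KB, too large for writeUTF

  // writeUTF fails because its length prefix is only 2 bytes.
  val out1 = new DataOutputStream(new ByteArrayOutputStream())
  try out1.writeUTF(largeString)
  catch { case _: UTFDataFormatException => println("writeUTF rejects strings over 65535 bytes") }

  // Workaround: 4-byte length prefix plus raw UTF-8 bytes.
  val bos = new ByteArrayOutputStream()
  val out2 = new DataOutputStream(bos)
  val bytes = largeString.getBytes(UTF_8)
  out2.writeInt(bytes.length)
  out2.write(bytes)

  val in = new DataInputStream(new ByteArrayInputStream(bos.toByteArray))
  val roundTripped = new Array[Byte](in.readInt())
  in.readFully(roundTripped)
  assert(new String(roundTripped, UTF_8) == largeString)
}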
Example 21
Source File: BloomFilter.scala    From bloom-filter-scala   with MIT License
package bloomfilter.mutable._128bit

import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream}

import bloomfilter.CanGenerate128HashFrom
import bloomfilter.mutable.UnsafeBitArray

import scala.math._

@SerialVersionUID(2L)
class BloomFilter[T] private (val numberOfBits: Long, val numberOfHashes: Int, private val bits: UnsafeBitArray)
    (implicit canGenerateHash: CanGenerate128HashFrom[T]) extends Serializable {

  def this(numberOfBits: Long, numberOfHashes: Int)(implicit canGenerateHash: CanGenerate128HashFrom[T]) {
    this(numberOfBits, numberOfHashes, new UnsafeBitArray(numberOfBits))
  }

  def add(x: T): Unit = {
    val hash = canGenerateHash.generateHash(x)

    var i = 0
    while (i < numberOfHashes) {
      val computedHash = hash._1 + i * hash._2
      bits.set((computedHash & Long.MaxValue) % numberOfBits)
      i += 1
    }
  }

  def mightContain(x: T): Boolean = {
    val hash = canGenerateHash.generateHash(x)

    var i = 0
    while (i < numberOfHashes) {
      val computedHash = hash._1 + i * hash._2
      if (!bits.get((computedHash & Long.MaxValue) % numberOfBits))
        return false
      i += 1
    }
    true
  }

  def expectedFalsePositiveRate(): Double = {
    math.pow(bits.getBitCount.toDouble / numberOfBits, numberOfHashes.toDouble)
  }

  def writeTo(out: OutputStream): Unit = {
    val dout = new DataOutputStream(out)
    dout.writeLong(numberOfBits)
    dout.writeInt(numberOfHashes)
    bits.writeTo(out)
  }

  def approximateElementCount(): Long = {
    val fractionOfBitsSet = bits.getBitCount.toDouble / numberOfBits
    val x = -log1p(-fractionOfBitsSet) * numberOfBits / numberOfHashes
    val z = rint(x)
    if (abs(x - z) == 0.5) {
      (x + Math.copySign(0.5, x)).toLong
    } else {
      z.toLong
    }
  }

  def dispose(): Unit = bits.dispose()

}

object BloomFilter {

  def apply[T](numberOfItems: Long, falsePositiveRate: Double)
      (implicit canGenerateHash: CanGenerate128HashFrom[T]): BloomFilter[T] = {

    val nb = optimalNumberOfBits(numberOfItems, falsePositiveRate)
    val nh = optimalNumberOfHashes(numberOfItems, nb)
    new BloomFilter[T](nb, nh)
  }

  def optimalNumberOfBits(numberOfItems: Long, falsePositiveRate: Double): Long = {
    math.ceil(-1 * numberOfItems * math.log(falsePositiveRate) / math.log(2) / math.log(2)).toLong
  }

  def optimalNumberOfHashes(numberOfItems: Long, numberOfBits: Long): Int = {
    math.ceil(numberOfBits / numberOfItems * math.log(2)).toInt
  }

  def readFrom[T](in: InputStream)(implicit canGenerateHash: CanGenerate128HashFrom[T]): BloomFilter[T] = {
    val din = new DataInputStream(in)
    val numberOfBits = din.readLong()
    val numberOfHashes = din.readInt()
    val bits = new UnsafeBitArray(numberOfBits)
    bits.readFrom(in)
    new BloomFilter[T](numberOfBits, numberOfHashes, bits)
  }

} 
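A hypothetical usage sketch of the writeTo/readFrom methods above, persisting a filter to a file and reading it back. It assumes the library's predefined CanGenerate128HashFrom instance for String is resolved from implicit scope (as in bloom-filter-scala's standard usage), and the file path is illustrative:

import java.io.{FileInputStream, FileOutputStream}
import bloomfilter.mutable._128bit.BloomFilter

object BloomFilterPersistence extends App {
  val filter = BloomFilter[String](numberOfItems = 1000000L, falsePositiveRate = 0.01)
  filter.add("some-key")

  // writeTo writes numberOfBits and numberOfHashes via a DataOutputStream, then the raw bit array.
  val out = new FileOutputStream("/tmp/filter.bin")
  try filter.writeTo(out) finally out.close()

  val in = new FileInputStream("/tmp/filter.bin")
  val restored = try BloomFilter.readFrom[String](in) finally in.close()
  println(restored.mightContain("some-key"))   // true; added keys are never false negatives

  filter.dispose()
  restored.dispose()
}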
Example 22
Source File: PythonGatewayServer.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket

import py4j.GatewayServer

import org.apache.spark.Logging
import org.apache.spark.util.Utils


private[spark] object PythonGatewayServer extends Logging {
  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // Start a GatewayServer on an ephemeral port
    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the bound port back to the caller via the caller-specified callback port
    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
    val callbackSocket = new Socket(callbackHost, callbackPort)
    val dos = new DataOutputStream(callbackSocket.getOutputStream)
    dos.writeInt(boundPort)
    dos.close()
    callbackSocket.close()

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
} 
Example 23
Source File: PythonRDDSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}

import org.apache.spark.SparkFunSuite

class PythonRDDSuite extends SparkFunSuite {

  test("Writing large strings to the worker") {
    val input: List[String] = List("a"*100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }

  test("Handle nulls gracefully") {
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    // Should not have NPE when write an Iterator with null in it
    // The correctness will be tested in Python
    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
    PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer)
    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
    PythonRDD.writeIteratorToStream(
      Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer)
  }
} 
Example 24
Source File: PlyOutputWriter.scala    From spark-iqmulus   with Apache License 2.0
package fr.ign.spark.iqmulus.ply

import org.apache.spark.sql.types._
import org.apache.hadoop.mapreduce.{ TaskAttemptID, RecordWriter, TaskAttemptContext, JobContext }
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import java.io.DataOutputStream
import org.apache.spark.sql.sources.OutputWriter
import org.apache.hadoop.io.{ NullWritable, BytesWritable }
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.fs.Path
import java.text.NumberFormat
import org.apache.spark.sql.{ Row, SQLContext, sources }
import fr.ign.spark.iqmulus.RowOutputStream

class PlyOutputWriter(
  name: String,
  context: TaskAttemptContext,
  dataSchema: StructType,
  element: String,
  littleEndian: Boolean
)
    extends OutputWriter {

  private val file = {
    val path = getDefaultWorkFile(s".ply.$element")
    val fs = path.getFileSystem(context.getConfiguration)
    fs.create(path)
  }

  private var count = 0L

  // strip out ids
  private val schema = StructType(dataSchema.filterNot { Seq("fid", "pid") contains _.name })

  private val recordWriter = new RowOutputStream(new DataOutputStream(file), littleEndian, schema, dataSchema)

  def getDefaultWorkFile(extension: String): Path = {
    val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID")
    val taskAttemptId: TaskAttemptID = context.getTaskAttemptID
    val split = taskAttemptId.getTaskID.getId
    new Path(name, f"$split%05d-$uniqueWriteJobId$extension")
  }

  override def write(row: Row): Unit = {
    recordWriter.write(row)
    count += 1
  }

  override def close(): Unit = {
    recordWriter.close

    // write header
    val path = getDefaultWorkFile(".ply.header")
    val fs = path.getFileSystem(context.getConfiguration)
    val dos = new java.io.DataOutputStream(fs.create(path))
    val header = new PlyHeader(path.toString, littleEndian, Map(element -> ((count, schema))))
    header.write(dos)
    dos.close
  }
} 
Example 25
Source File: LasOutputWriter.scala    From spark-iqmulus   with Apache License 2.0
package fr.ign.spark.iqmulus.las

import org.apache.spark.sql.types._
import org.apache.hadoop.mapreduce.{ TaskAttemptID, RecordWriter, TaskAttemptContext }
import java.io.DataOutputStream
import org.apache.spark.sql.sources.OutputWriter
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.hadoop.io.{ NullWritable, BytesWritable }
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.fs.Path
import java.text.NumberFormat
import org.apache.spark.sql.{ Row, SQLContext, sources }
import fr.ign.spark.iqmulus.RowOutputStream

class LasOutputWriter(
  name: String,
  context: TaskAttemptContext,
  dataSchema: StructType,
  formatOpt: Option[Byte] = None,
  version: Version = Version(),
  offset: Array[Double] = Array(0F, 0F, 0F),
  scale: Array[Double] = Array(0.01F, 0.01F, 0.01F)
)
    extends OutputWriter {

  private val file = {
    val path = getDefaultWorkFile("/1.pdr")
    val fs = path.getFileSystem(context.getConfiguration)
    fs.create(path)
  }

  private val pmin = Array.fill[Double](3)(Double.PositiveInfinity)
  private val pmax = Array.fill[Double](3)(Double.NegativeInfinity)
  private val countByReturn = Array.fill[Long](15)(0)
  private def count = countByReturn.sum

  private val format = formatOpt.getOrElse(LasHeader.formatFromSchema(dataSchema))

  // todo, extra bytes
  private val schema = LasHeader.schema(format)
  private def header =
    new LasHeader(name, format, count, pmin, pmax, scale, offset, countByReturn)

  private val recordWriter = new RowOutputStream(new DataOutputStream(file), littleEndian = true, schema, dataSchema)

  def getDefaultWorkFile(extension: String): Path = {
    val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID")
    val taskAttemptId: TaskAttemptID = context.getTaskAttemptID
    val split = taskAttemptId.getTaskID.getId
    new Path(name, f"$split%05d-$uniqueWriteJobId$extension")
  }

  override def write(row: Row): Unit = {
    recordWriter.write(row)

    // gather statistics for the header
    val x = offset(0) + scale(0) * row.getAs[Int]("x").toDouble
    val y = offset(1) + scale(1) * row.getAs[Int]("y").toDouble
    val z = offset(2) + scale(2) * row.getAs[Int]("z").toDouble
    val ret = row.getAs[Byte]("flags") & 0x3
    countByReturn(ret) += 1
    pmin(0) = Math.min(pmin(0), x)
    pmin(1) = Math.min(pmin(1), y)
    pmin(2) = Math.min(pmin(2), z)
    pmax(0) = Math.max(pmax(0), x)
    pmax(1) = Math.max(pmax(1), y)
    pmax(2) = Math.max(pmax(2), z)
  }

  override def close(): Unit = {
    recordWriter.close

    // write header
    val path = getDefaultWorkFile("/0.header")
    val fs = path.getFileSystem(context.getConfiguration)
    val dos = new java.io.DataOutputStream(fs.create(path))
    header.write(dos)
    dos.close

    // copy header and pdf to a final las file (1 per split)
    org.apache.hadoop.fs.FileUtil.copyMerge(
      fs, getDefaultWorkFile("/"),
      fs, getDefaultWorkFile(".las"),
      true, context.getConfiguration, ""
    )
  }
} 
Example 26
Source File: TextLINEModelOutputFormat.scala    From sona   with Apache License 2.0
package com.tencent.angel.sona.graph.embedding.line2

import java.io.{DataInputStream, DataOutputStream}

import com.tencent.angel.model.output.format.{ComplexRowFormat, IndexAndElement}
import com.tencent.angel.ps.storage.vector.element.IElement
import org.apache.hadoop.conf.Configuration

class TextLINEModelOutputFormat(conf:Configuration) extends ComplexRowFormat {
  val featSep = conf.get("line.feature.sep", " ")
  val keyValueSep = conf.get("line.keyvalue.sep", ":")
  val modelOrder = conf.get("line.model.order", "2").toInt

  override def load(input: DataInputStream): IndexAndElement = {
    val line = input.readLine()
    val indexAndElement = new IndexAndElement
    val keyValues = line.split(keyValueSep)

    if(featSep.equals(keyValueSep)) {
      indexAndElement.index = keyValues(0).toLong
      val feats = new Array[Float](keyValues.length - 1)
      (1 until keyValues.length).foreach(i => feats(i - 1) = keyValues(i).toFloat)
      indexAndElement.element = new LINENode(feats, null)
    } else {
      indexAndElement.index = keyValues(0).toLong
      val inputFeats = keyValues(1).split(featSep).map(f => f.toFloat)
      if(modelOrder == 1) {
        indexAndElement.element = new LINENode(inputFeats, null)
      } else {
        indexAndElement.element = new LINENode(inputFeats, new Array[Float](inputFeats.length))
      }

    }
    indexAndElement
  }

  override def save(key: Long, value: IElement, output: DataOutputStream): Unit = {
    val sb = new StringBuilder

    // Write node id
    sb.append(key)
    sb.append(keyValueSep)

    // Write feats
    val feats = value.asInstanceOf[LINENode].getInputFeats

    var index = 0
    val len = feats.length
    feats.foreach(f => {
      if(index < len - 1) {
        sb.append(f).append(featSep)
      } else {
        sb.append(f).append("\n")
      }
      index += 1
    })

    output.writeBytes(sb.toString())
  }
} 
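Note that output.writeBytes(...) above writes only the low-order byte of each character, which is fine for the ASCII text produced here but lossy for non-ASCII data. A small sketch contrasting the three common ways to push a string through a DataOutputStream (StringWriteVariants is an illustrative name):

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets.UTF_8

object StringWriteVariants extends App {
  def sizeOf(write: DataOutputStream => Unit): Int = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    write(out)
    out.flush()
    bos.size()
  }

  val s = "naïve"                                      // contains a non-ASCII character
  println(sizeOf(_.writeBytes(s)))                     // 5 bytes: one low byte per char (lossy)
  println(sizeOf(_.writeUTF(s)))                       // 2-byte length prefix + modified UTF-8
  println(sizeOf(out => out.write(s.getBytes(UTF_8)))) // raw UTF-8 bytes, no length prefix
}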
Example 27
Source File: ModelToServeStats.scala    From kafka-with-akka-streams-kafka-streams-tutorial   with Apache License 2.0
package com.lightbend.scala.modelServer.model

import java.io.DataInputStream
import java.io.DataOutputStream

import scala.util.control.NonFatal


final case class ModelToServeStats(
  name: String,
  description: String,
  since: Long,
  usage: Long = 0,
  duration: Double = 0.0,
  min: Long = 0,
  max: Long = 0) {

  def incrementUsage(execution: Long): ModelToServeStats = copy(
    usage = usage + 1,
    duration = duration + execution,
    min = if (execution < min) execution else min,
    max = if (execution > max) execution else max)
}

object ModelToServeStats {
  val empty = ModelToServeStats("None", "None", 0)

  def apply(m: ModelToServe): ModelToServeStats =
    ModelToServeStats(m.name, m.description, System.currentTimeMillis())

  def readServingInfo(input: DataInputStream) : Option[ModelToServeStats] = {
    input.readLong match {
      case length if length > 0 =>
        try {
          Some(ModelToServeStats(input.readUTF, input.readUTF, input.readLong, input.readLong, input.readDouble, input.readLong, input.readLong))
        } catch {
          case NonFatal(e) =>
            System.out.println("Error Deserializing serving info")
            e.printStackTrace()
            None
        }

      case _ => None
    }
  }

  def writeServingInfo(output: DataOutputStream, servingInfo: ModelToServeStats ): Unit = {
    if(servingInfo == null)
      output.writeLong(0)
    else {
      try {
        output.writeLong(5)
        output.writeUTF(servingInfo.description)
        output.writeUTF(servingInfo.name)
        output.writeLong(servingInfo.since)
        output.writeLong(servingInfo.usage)
        output.writeDouble(servingInfo.duration)
        output.writeLong(servingInfo.min)
        output.writeLong(servingInfo.max)
      } catch {
        case NonFatal(e) =>
          System.out.println("Error Serializing servingInfo")
          e.printStackTrace()
      }
    }
  }
} 
Example 28
Source File: ModelWithDescriptor.scala    From kafka-with-akka-streams-kafka-streams-tutorial   with Apache License 2.0
package com.lightbend.scala.modelServer.model

import java.io.{DataInputStream, DataOutputStream}

import com.lightbend.model.modeldescriptor.ModelDescriptor

import scala.collection.Map
import com.lightbend.scala.modelServer.model.PMML.PMMLModel
import com.lightbend.scala.modelServer.model.tensorflow.TensorFlowModel

import scala.util.Try


case class ModelWithDescriptor(model: Model, descriptor: ModelToServe){}

object ModelWithDescriptor {

  private val factories = Map(
    ModelDescriptor.ModelType.PMML.name -> PMMLModel,
    ModelDescriptor.ModelType.TENSORFLOW.name -> TensorFlowModel
  )

  private val factoriesInt = Map(
    ModelDescriptor.ModelType.PMML.index -> PMMLModel,
    ModelDescriptor.ModelType.TENSORFLOW.index -> TensorFlowModel
  )

  def fromModelToServe(descriptor : ModelToServe): Try[ModelWithDescriptor] = Try{
    println(s"New model - $descriptor")
    factories.get(descriptor.modelType.name) match {
      case Some(factory) => ModelWithDescriptor(factory.create(descriptor),descriptor)
      case _ => throw new Throwable("Undefined model type")
    }
  }

  def readModel(input : DataInputStream) : Option[Model] = {
    input.readLong.toInt match{
      case length if length > 0 =>
        val `type` = input.readLong.toInt
        val bytes = new Array[Byte](length)
        input.read(bytes)
        factoriesInt.get(`type`) match {
          case Some(factory) => try {
            Some(factory.restore(bytes))
          }
          catch {
            case t: Throwable =>
              System.out.println("Error Deserializing model")
              t.printStackTrace()
              None
          }
          case _ => None
        }
      case _ => None
    }
  }

  def writeModel(output : DataOutputStream, model: Model) : Unit = {
    if(model == null)
      output.writeLong(0l)
    else {
      try {
        val bytes = model.toBytes
        output.writeLong(bytes.length)
        output.writeLong(model.getType)
        output.write(bytes)
      } catch {
        case t: Throwable =>
          System.out.println("Error Serializing model")
          t.printStackTrace()
      }
    }
  }
} 
Example 29
Source File: ModelStateSerde.scala    From kafka-with-akka-streams-kafka-streams-tutorial   with Apache License 2.0
package com.lightbend.scala.kafkastreams.store.store

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util

import com.lightbend.model.modeldescriptor.ModelDescriptor
import com.lightbend.scala.modelServer.model.PMML.PMMLModel
import com.lightbend.scala.modelServer.model.tensorflow.TensorFlowModel
import com.lightbend.scala.modelServer.model.{ModelToServeStats, ModelWithDescriptor}
import com.lightbend.scala.kafkastreams.store.StoreState
import org.apache.kafka.common.serialization.{Deserializer, Serde, Serializer}


class ModelStateSerde extends Serde[StoreState] {

  private val mserializer = new ModelStateSerializer()
  private val mdeserializer = new ModelStateDeserializer()

  override def deserializer() = mdeserializer

  override def serializer() = mserializer

  override def configure(configs: util.Map[String, _], isKey: Boolean) = {}

  override def close() = {}
}

object ModelStateDeserializer {
  val factories = Map(
    ModelDescriptor.ModelType.PMML.index -> PMMLModel,
    ModelDescriptor.ModelType.TENSORFLOW.index -> TensorFlowModel
  )
}

class ModelStateDeserializer extends Deserializer[StoreState] {

  override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = {}

  override def deserialize(topic: String, data: Array[Byte]): StoreState = {
    if(data != null) {
      val input = new DataInputStream(new ByteArrayInputStream(data))
      new StoreState(ModelWithDescriptor.readModel(input), ModelWithDescriptor.readModel(input),
        ModelToServeStats.readServingInfo(input), ModelToServeStats.readServingInfo(input))
    }
    else new StoreState()
  }

  override def close(): Unit = {}

}

class ModelStateSerializer extends Serializer[StoreState] {

  private val bos = new ByteArrayOutputStream()

  override def serialize(topic: String, state: StoreState): Array[Byte] = {
    bos.reset()
    val output = new DataOutputStream(bos)
    ModelWithDescriptor.writeModel(output, state.currentModel.orNull)
    ModelWithDescriptor.writeModel(output, state.newModel.orNull)
    ModelToServeStats.writeServingInfo(output, state.currentState.orNull)
    ModelToServeStats.writeServingInfo(output, state.newState.orNull)
    try {
      output.flush()
      output.close()
    } catch {
      case t: Throwable =>
    }
    bos.toByteArray
  }

  override def close(): Unit = {}

  override def configure(configs: util.Map[String, _], isKey: Boolean) = {}
} 
Example 30
Source File: DefaultRowWriter.scala    From mleap   with Apache License 2.0
package ml.combust.mleap.binary

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.Charset

import ml.combust.mleap.runtime.serialization.{BuiltinFormats, RowWriter}
import ml.combust.mleap.core.types.StructType
import ml.combust.mleap.runtime.frame.Row
import resource._

import scala.util.Try


class DefaultRowWriter(override val schema: StructType) extends RowWriter {
  private val serializers = schema.fields.map(_.dataType).map(ValueSerializer.serializerForDataType)

  override def toBytes(row: Row, charset: Charset = BuiltinFormats.charset): Try[Array[Byte]] = {
    (for(out <- managed(new ByteArrayOutputStream())) yield {
      val dout = new DataOutputStream(out)
      var i = 0
      for(s <- serializers) {
        s.write(row.getRaw(i), dout)
        i = i + 1
      }
      dout.flush()
      out.toByteArray
    }).tried
  }
} 
Example 31
Source File: Serialization.scala    From daml   with Apache License 2.0
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils.export

import java.io.{DataInputStream, DataOutputStream}
import java.time.Instant

import com.daml.ledger.participant.state
import com.daml.ledger.participant.state.kvutils.export.FileBasedLedgerDataExporter.{
  SubmissionInfo,
  WriteSet
}
import com.google.protobuf.ByteString

object Serialization {
  def serializeEntry(
      submissionInfo: SubmissionInfo,
      writeSet: WriteSet,
      out: DataOutputStream): Unit = {
    serializeSubmissionInfo(submissionInfo, out)
    serializeWriteSet(writeSet, out)
  }

  def readEntry(input: DataInputStream): (SubmissionInfo, WriteSet) = {
    val submissionInfo = readSubmissionInfo(input)
    val writeSet = readWriteSet(input)
    (submissionInfo, writeSet)
  }

  private def serializeSubmissionInfo(
      submissionInfo: SubmissionInfo,
      out: DataOutputStream): Unit = {
    out.writeUTF(submissionInfo.correlationId)
    out.writeInt(submissionInfo.submissionEnvelope.size())
    submissionInfo.submissionEnvelope.writeTo(out)
    out.writeLong(submissionInfo.recordTimeInstant.toEpochMilli)
    out.writeUTF(submissionInfo.participantId)
  }

  private def readSubmissionInfo(input: DataInputStream): SubmissionInfo = {
    val correlationId = input.readUTF()
    val submissionEnvelopeSize = input.readInt()
    val submissionEnvelope = new Array[Byte](submissionEnvelopeSize)
    input.readFully(submissionEnvelope)
    val recordTimeEpochMillis = input.readLong()
    val participantId = input.readUTF()
    SubmissionInfo(
      ByteString.copyFrom(submissionEnvelope),
      correlationId,
      Instant.ofEpochMilli(recordTimeEpochMillis),
      state.v1.ParticipantId.assertFromString(participantId)
    )
  }

  private def serializeWriteSet(writeSet: WriteSet, out: DataOutputStream): Unit = {
    out.writeInt(writeSet.size)
    for ((key, value) <- writeSet.sortBy(_._1.asReadOnlyByteBuffer())) {
      out.writeInt(key.size())
      key.writeTo(out)
      out.writeInt(value.size())
      value.writeTo(out)
    }
  }

  private def readWriteSet(input: DataInputStream): WriteSet = {
    val numKeyValuePairs = input.readInt()
    (1 to numKeyValuePairs).map { _ =>
      val keySize = input.readInt()
      val keyBytes = new Array[Byte](keySize)
      input.readFully(keyBytes)
      val valueSize = input.readInt()
      val valueBytes = new Array[Byte](valueSize)
      input.readFully(valueBytes)
      (ByteString.copyFrom(keyBytes), ByteString.copyFrom(valueBytes))
    }
  }
} 
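serializeWriteSet and readWriteSet rely on simple length-prefixed framing: an entry count followed by <int32 length><bytes> pairs. A standalone sketch of that framing with raw byte arrays (no daml types assumed):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object LengthPrefixedFramingSketch {

  // Write the entry count, then each entry as <int32 length><bytes>.
  def write(entries: Seq[Array[Byte]], out: DataOutputStream): Unit = {
    out.writeInt(entries.size)
    entries.foreach { bytes =>
      out.writeInt(bytes.length)
      out.write(bytes)
    }
  }

  // Read the entries back in the order they were written.
  def read(in: DataInputStream): Seq[Array[Byte]] =
    (1 to in.readInt()).map { _ =>
      val bytes = new Array[Byte](in.readInt())
      in.readFully(bytes)
      bytes
    }

  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    write(Seq("key".getBytes("UTF-8"), "value".getBytes("UTF-8")), new DataOutputStream(bos))
    val entries = read(new DataInputStream(new ByteArrayInputStream(bos.toByteArray)))
    println(entries.map(new String(_, "UTF-8")).mkString(", ")) // key, value
  }
}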
Example 32
Source File: FileBasedLedgerDataExporter.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils.export

import java.io.DataOutputStream
import java.time.Instant
import java.util.concurrent.locks.StampedLock

import com.daml.ledger.participant.state.v1.ParticipantId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
import com.google.protobuf.ByteString

import scala.collection.mutable
import scala.collection.mutable.ListBuffer


class FileBasedLedgerDataExporter(output: DataOutputStream) extends LedgerDataExporter {
  import FileBasedLedgerDataExporter._

  private val outputLock = new StampedLock

  private[export] val correlationIdMapping = mutable.Map.empty[String, String]
  private[export] val inProgressSubmissions = mutable.Map.empty[String, SubmissionInfo]
  private[export] val bufferedKeyValueDataPerCorrelationId =
    mutable.Map.empty[String, mutable.ListBuffer[(Key, Value)]]

  def addSubmission(
      submissionEnvelope: ByteString,
      correlationId: String,
      recordTimeInstant: Instant,
      participantId: ParticipantId): Unit =
    this.synchronized {
      inProgressSubmissions.put(
        correlationId,
        SubmissionInfo(submissionEnvelope, correlationId, recordTimeInstant, participantId))
      ()
    }

  def addParentChild(parentCorrelationId: String, childCorrelationId: String): Unit =
    this.synchronized {
      correlationIdMapping.put(childCorrelationId, parentCorrelationId)
      ()
    }

  def addToWriteSet(correlationId: String, data: Iterable[(Key, Value)]): Unit =
    this.synchronized {
      correlationIdMapping
        .get(correlationId)
        .foreach { parentCorrelationId =>
          val keyValuePairs = bufferedKeyValueDataPerCorrelationId
            .getOrElseUpdate(parentCorrelationId, ListBuffer.empty)
          keyValuePairs.appendAll(data)
          bufferedKeyValueDataPerCorrelationId.put(parentCorrelationId, keyValuePairs)
        }
    }

  def finishedProcessing(correlationId: String): Unit = {
    val (submissionInfo, bufferedData) = this.synchronized {
      (
        inProgressSubmissions.get(correlationId),
        bufferedKeyValueDataPerCorrelationId.get(correlationId))
    }
    submissionInfo.foreach { submission =>
      bufferedData.foreach(writeSubmissionData(submission, _))
      this.synchronized {
        inProgressSubmissions.remove(correlationId)
        bufferedKeyValueDataPerCorrelationId.remove(correlationId)
        correlationIdMapping
          .collect {
            case (key, value) if value == correlationId => key
          }
          .foreach(correlationIdMapping.remove)
      }
    }
  }

  private def writeSubmissionData(
      submissionInfo: SubmissionInfo,
      writeSet: ListBuffer[(Key, Value)]): Unit = {
    val stamp = outputLock.writeLock()
    try {
      Serialization.serializeEntry(submissionInfo, writeSet, output)
      output.flush()
    } finally {
      outputLock.unlock(stamp)
    }
  }
}

object FileBasedLedgerDataExporter {
  case class SubmissionInfo(
      submissionEnvelope: ByteString,
      correlationId: String,
      recordTimeInstant: Instant,
      participantId: ParticipantId)

  type WriteSet = Seq[(Key, Value)]
} 
Example 33
Source File: LedgerDataExporter.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils.export

import java.io.{DataOutputStream, FileOutputStream}
import java.time.Instant

import com.daml.ledger.participant.state.v1.ParticipantId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
import com.google.protobuf.ByteString
import org.slf4j.LoggerFactory

trait LedgerDataExporter {

  
  // Called once the submission identified by the given correlation ID has been fully processed.
  def finishedProcessing(correlationId: String): Unit
}

object LedgerDataExporter {
  val EnvironmentVariableName = "KVUTILS_LEDGER_EXPORT"

  private val logger = LoggerFactory.getLogger(this.getClass)

  private lazy val outputStreamMaybe: Option[DataOutputStream] = {
    Option(System.getenv(EnvironmentVariableName))
      .map { filename =>
        logger.info(s"Enabled writing ledger entries to $filename")
        new DataOutputStream(new FileOutputStream(filename))
      }
  }

  private lazy val instance = outputStreamMaybe
    .map(new FileBasedLedgerDataExporter(_))
    .getOrElse(NoopLedgerDataExporter)

  def apply(): LedgerDataExporter = instance
} 
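The pattern here, an Option-wrapped DataOutputStream that is only created when an environment variable is set, can be reused outside kvutils. A small sketch under that assumption; the variable name and record layout below are hypothetical.

import java.io.{DataOutputStream, FileOutputStream}

object OptionalFileSinkSketch {

  // Hypothetical environment variable; only set it when an export file is wanted.
  private val EnvironmentVariableName = "EXAMPLE_EXPORT_FILE"

  private lazy val maybeStream: Option[DataOutputStream] =
    Option(System.getenv(EnvironmentVariableName)).map { filename =>
      new DataOutputStream(new FileOutputStream(filename))
    }

  // Writing becomes a no-op when the environment variable is absent.
  def writeRecord(id: Int, payload: String): Unit =
    maybeStream.foreach { out =>
      out.writeInt(id)
      out.writeUTF(payload)
      out.flush()
    }
}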
Example 34
Source File: FileBasedLedgerDataExportSpec.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils.export

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.time.Instant

import com.daml.ledger.participant.state.v1
import com.google.protobuf.ByteString
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{Matchers, WordSpec}

class FileBasedLedgerDataExportSpec extends WordSpec with Matchers with MockitoSugar {
  // XXX SC remove in Scala 2.13; see notes in ConfSpec
  import scala.collection.GenTraversable, org.scalatest.enablers.Containing
  private[this] implicit def `fixed sig containingNatureOfGenTraversable`[
      E: org.scalactic.Equality,
      TRAV]: Containing[TRAV with GenTraversable[E]] =
    Containing.containingNatureOfGenTraversable[E, GenTraversable]

  "addParentChild" should {
    "add entry to correlation ID mapping" in {
      val instance = new FileBasedLedgerDataExporter(mock[DataOutputStream])
      instance.addParentChild("parent", "child")

      instance.correlationIdMapping should contain("child" -> "parent")
    }
  }

  "addToWriteSet" should {
    "append to existing data" in {
      val instance = new FileBasedLedgerDataExporter(mock[DataOutputStream])
      instance.addParentChild("parent", "child")
      instance.addToWriteSet("child", Seq(keyValuePairOf("a", "b")))
      instance.addToWriteSet("child", Seq(keyValuePairOf("c", "d")))

      instance.bufferedKeyValueDataPerCorrelationId should contain(
        "parent" ->
          Seq(keyValuePairOf("a", "b"), keyValuePairOf("c", "d")))
    }
  }

  "finishedProcessing" should {
    "remove all data such as submission info, write-set and child correlation IDs" in {
      val dataOutputStream = new DataOutputStream(new ByteArrayOutputStream())
      val instance = new FileBasedLedgerDataExporter(dataOutputStream)
      instance.addSubmission(
        ByteString.copyFromUtf8("an envelope"),
        "parent",
        Instant.now(),
        v1.ParticipantId.assertFromString("id"))
      instance.addParentChild("parent", "parent")
      instance.addToWriteSet("parent", Seq(keyValuePairOf("a", "b")))

      instance.finishedProcessing("parent")

      instance.inProgressSubmissions shouldBe empty
      instance.bufferedKeyValueDataPerCorrelationId shouldBe empty
      instance.correlationIdMapping shouldBe empty
    }
  }

  "serialized submission" should {
    "be readable back" in {
      val baos = new ByteArrayOutputStream()
      val dataOutputStream = new DataOutputStream(baos)
      val instance = new FileBasedLedgerDataExporter(dataOutputStream)
      val expectedRecordTimeInstant = Instant.now()
      val expectedParticipantId = v1.ParticipantId.assertFromString("id")
      instance.addSubmission(
        ByteString.copyFromUtf8("an envelope"),
        "parent",
        expectedRecordTimeInstant,
        v1.ParticipantId.assertFromString("id"))
      instance.addParentChild("parent", "parent")
      instance.addToWriteSet("parent", Seq(keyValuePairOf("a", "b")))

      instance.finishedProcessing("parent")

      val dataInputStream = new DataInputStream(new ByteArrayInputStream(baos.toByteArray))
      val (actualSubmissionInfo, actualWriteSet) = Serialization.readEntry(dataInputStream)
      actualSubmissionInfo.submissionEnvelope should be(ByteString.copyFromUtf8("an envelope"))
      actualSubmissionInfo.correlationId should be("parent")
      actualSubmissionInfo.recordTimeInstant should be(expectedRecordTimeInstant)
      actualSubmissionInfo.participantId should be(expectedParticipantId)
      actualWriteSet should be(Seq(keyValuePairOf("a", "b")))
    }
  }

  private def keyValuePairOf(key: String, value: String): (ByteString, ByteString) =
    ByteString.copyFromUtf8(key) -> ByteString.copyFromUtf8(value)
} 
Example 35
Source File: PythonGatewayServer.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket

import py4j.GatewayServer

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[spark] object PythonGatewayServer extends Logging {
  initializeLogIfNecessary(true)

  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // Start a GatewayServer on an ephemeral port
    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the bound port back to the caller via the caller-specified callback port
    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
    val callbackSocket = new Socket(callbackHost, callbackPort)
    val dos = new DataOutputStream(callbackSocket.getOutputStream)
    dos.writeInt(boundPort)
    dos.close()
    callbackSocket.close()

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
} 
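On the other side of this handshake something must accept the connection and read the int back. A minimal receiving end using only java.net and java.io; the announcement protocol is just writeInt/readInt.

import java.io.DataInputStream
import java.net.ServerSocket

object CallbackPortReceiverSketch {
  def main(args: Array[String]): Unit = {
    // Listen on an ephemeral port and read the single int the peer writes with DataOutputStream.writeInt.
    val server = new ServerSocket(0)
    println(s"Waiting for a callback on port ${server.getLocalPort}")
    val socket = server.accept()
    val in = new DataInputStream(socket.getInputStream)
    val announcedPort = in.readInt() // pairs with dos.writeInt(boundPort) above
    println(s"Peer announced port $announcedPort")
    in.close()
    socket.close()
    server.close()
  }
}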
Example 36
Source File: PortableDataStream.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.input

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import scala.collection.JavaConverters._

import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}


  def toArray(): Array[Byte] = {
    val stream = open()
    try {
      ByteStreams.toByteArray(stream)
    } finally {
      Closeables.close(stream, true)
    }
  }

  def getPath(): String = path
} 
Example 37
Source File: PythonRDDSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkFunSuite

class PythonRDDSuite extends SparkFunSuite {

  test("Writing large strings to the worker") {
    val input: List[String] = List("a"*100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }

  test("Handle nulls gracefully") {
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    // Should not have NPE when write an Iterator with null in it
    // The correctness will be tested in Python
    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
    PythonRDD.writeIteratorToStream(Iterator("a".getBytes(StandardCharsets.UTF_8), null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes(StandardCharsets.UTF_8)), buffer)
    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
    PythonRDD.writeIteratorToStream(Iterator(
      (null, null),
      ("a".getBytes(StandardCharsets.UTF_8), null),
      (null, "b".getBytes(StandardCharsets.UTF_8))), buffer)
  }
} 
Example 38
Source File: MasterWebUISuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  // Send an HTTP request with the given method and optional form-encoded body, returning the open connection.
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 39
Source File: SocketService.scala    From HadoopLearning   with MIT License 5 votes vote down vote up
package com.utils

import java.io.DataOutputStream
import java.net.ServerSocket


object SocketService {

  def main(args: Array[String]): Unit = {

    new Thread(new SparkSocket()).start()

  }

  class SparkSocket extends Runnable {
    override def run(): Unit = {
      val server = new ServerSocket(8880)
      while (true) {
        println("等待客户端进行连接....")
        val socket = server.accept()
        //发送给客户端数据
        val stream = new DataOutputStream(socket.getOutputStream())
        println("发送消息了....")
        stream.writeUTF("The Indian government has decided2 to scrap3 a controversial 12% tax on the feminine")
        //五百毫秒停顿一下
        stream.flush()
        stream.close()
      }
    }

  }

} 
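A matching client for this server connects to port 8880 and reads the string with DataInputStream.readUTF, the symmetric counterpart of writeUTF. A minimal sketch:

import java.io.DataInputStream
import java.net.Socket

object SocketClientSketch {
  def main(args: Array[String]): Unit = {
    // Connect to the server above and read the string it wrote with writeUTF.
    val socket = new Socket("localhost", 8880)
    val in = new DataInputStream(socket.getInputStream)
    println(in.readUTF())
    in.close()
    socket.close()
  }
}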
Example 40
Source File: WriSer.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.hadoop

import java.io.{ DataInputStream, DataOutputStream, ObjectInputStream, ObjectOutputStream }
import java.io.IOException

import scala.reflect.{ classTag, ClassTag }

import org.apache.hadoop.io.Writable

// Note: we could make this implement InputSplit, but we do not because many input splits do a
// cast to their specific InputSplit, so we do not want to risk it. Further, this currently works
// for any Writable.
case class WriSer[T <: Writable: ClassTag](@transient var get: T) extends Serializable {
  def this() = this(null.asInstanceOf[T])

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream) {
    out.writeObject(classTag[T])
    get.write(new DataOutputStream(out))
  }

  @throws(classOf[IOException])
  @throws(classOf[ClassNotFoundException])
  private def readObject(in: ObjectInputStream) {
    get = in.readObject.asInstanceOf[ClassTag[T]].runtimeClass.newInstance.asInstanceOf[T]
    get.readFields(new DataInputStream(in))
  }
} 
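WriSer plugs a Writable's write/readFields into Java serialization, so a wrapped value can go through a plain ObjectOutputStream. A usage sketch, assuming hadoop-common (for IntWritable) is on the classpath:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import com.twosigma.flint.hadoop.WriSer
import org.apache.hadoop.io.IntWritable

object WriSerUsageSketch {
  def main(args: Array[String]): Unit = {
    // Java-serialize the wrapper; writeObject above delegates to IntWritable.write(DataOutputStream).
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(WriSer(new IntWritable(42)))
    oos.close()

    // Deserialize; readObject above instantiates IntWritable and calls readFields(DataInputStream).
    val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    val restored = ois.readObject().asInstanceOf[WriSer[IntWritable]]
    println(restored.get.get()) // 42
  }
}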
Example 41
Source File: NeuralNetwork.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package Yelp.Trainer

import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.linalg.factory.Nd4j
import java.io.File
import org.apache.commons.io.FileUtils
import java.io.{DataInputStream, DataOutputStream, FileInputStream}
import java.nio.file.{Files, Paths}

object NeuralNetwork {  
  def loadNN(NNconfig: String, NNparams: String) = {
    // get neural network config
    val confFromJson: MultiLayerConfiguration = MultiLayerConfiguration.fromJson(FileUtils.readFileToString(new File(NNconfig)))    
     // get neural network parameters 
    val dis: DataInputStream = new DataInputStream(new FileInputStream(NNparams))
    val newParams = Nd4j.read(dis)    
     // creating network object
    val savedNetwork: MultiLayerNetwork = new MultiLayerNetwork(confFromJson)
    savedNetwork.init()
    savedNetwork.setParameters(newParams)    
    savedNetwork
  }
  
  def saveNN(model: MultiLayerNetwork, NNconfig: String, NNparams: String) = {
    // save neural network config
    FileUtils.write(new File(NNconfig), model.getLayerWiseConfigurations().toJson())     
    // save neural network parms
    val dos: DataOutputStream = new DataOutputStream(Files.newOutputStream(Paths.get(NNparams)))
    Nd4j.write(model.params(), dos)
  }  
} 
Example 42
Source File: TestingTypedCount.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.hive.execution.TestingTypedCount.State
import org.apache.spark.sql.types._

@ExpressionDescription(
  usage = "_FUNC_(expr) - A testing aggregate function resembles COUNT " +
          "but implements ObjectAggregateFunction.")
case class TestingTypedCount(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[TestingTypedCount.State] {

  def this(child: Expression) = this(child, 0, 0)

  override def children: Seq[Expression] = child :: Nil

  override def dataType: DataType = LongType

  override def nullable: Boolean = false

  override def createAggregationBuffer(): State = TestingTypedCount.State(0L)

  override def update(buffer: State, input: InternalRow): State = {
    if (child.eval(input) != null) {
      buffer.count += 1
    }
    buffer
  }

  override def merge(buffer: State, input: State): State = {
    buffer.count += input.count
    buffer
  }

  override def eval(buffer: State): Any = buffer.count

  override def serialize(buffer: State): Array[Byte] = {
    val byteStream = new ByteArrayOutputStream()
    val dataStream = new DataOutputStream(byteStream)
    dataStream.writeLong(buffer.count)
    byteStream.toByteArray
  }

  override def deserialize(storageFormat: Array[Byte]): State = {
    val byteStream = new ByteArrayInputStream(storageFormat)
    val dataStream = new DataInputStream(byteStream)
    TestingTypedCount.State(dataStream.readLong())
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override val prettyName: String = "typed_count"
}

object TestingTypedCount {
  case class State(var count: Long)
} 
Example 43
Source File: WritableSerializer.scala    From spark-util   with Apache License 2.0 5 votes vote down vote up
package org.hammerlab.hadoop.kryo

import java.io.{ DataInputStream, DataOutputStream }

import com.esotericsoftware.kryo
import com.esotericsoftware.kryo.io.{ Input, Output }
import com.esotericsoftware.kryo.{ Kryo, Serializer }
import org.apache.hadoop.io.Writable


class WritableSerializer[T <: Writable](ctorArgs: Any*)
  extends kryo.Serializer[T] {
  override def read(kryo: Kryo, input: Input, clz: Class[T]): T = {
    val t = clz.newInstance()
    t.readFields(new DataInputStream(input))
    t
  }

  override def write(kryo: Kryo, output: Output, t: T): Unit = {
    t.write(new DataOutputStream(output))
  }
} 
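To use this serializer, register it on a Kryo instance for each Writable class you need; Kryo then calls the write/read methods above, which wrap its Output/Input in Data(Output/Input)Stream. A registration sketch, assuming kryo and hadoop-common on the classpath:

import com.esotericsoftware.kryo.Kryo
import org.apache.hadoop.io.{IntWritable, Text}
import org.hammerlab.hadoop.kryo.WritableSerializer

object WritableSerializerRegistrationSketch {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    // Route (de)serialization of these Writables through the DataOutputStream-backed serializer above.
    kryo.register(classOf[IntWritable], new WritableSerializer[IntWritable]())
    kryo.register(classOf[Text], new WritableSerializer[Text]())
  }
}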
Example 44
Source File: DefaultFrameWriter.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package ml.combust.mleap.binary

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.Charset

import ml.combust.mleap.json.JsonSupport._
import ml.combust.mleap.runtime.frame.LeapFrame
import ml.combust.mleap.runtime.serialization.{BuiltinFormats, FrameWriter}
import spray.json._
import resource._

import scala.util.Try


class DefaultFrameWriter[LF <: LeapFrame[LF]](frame: LF) extends FrameWriter {
  override def toBytes(charset: Charset = BuiltinFormats.charset): Try[Array[Byte]] = {
    (for(out <- managed(new ByteArrayOutputStream())) yield {
      val serializers = frame.schema.fields.map(_.dataType).map(ValueSerializer.serializerForDataType)
      val dout = new DataOutputStream(out)
      val schemaBytes = frame.schema.toJson.prettyPrint.getBytes(BuiltinFormats.charset)
      val rows = frame.collect()
      dout.writeInt(schemaBytes.length)
      dout.write(schemaBytes)
      dout.writeInt(rows.size)

      for(row <- rows) {
        var i = 0
        for(s <- serializers) {
          s.write(row.getRaw(i), dout)
          i = i + 1
        }
      }

      dout.flush()
      out.toByteArray
    }).tried
  }
} 
Example 45
Source File: Debug.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils

import java.io.{DataOutputStream, FileOutputStream}

import com.daml.ledger.participant.state.kvutils.DamlKvutils._
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._


  def dumpLedgerEntry(
      submission: DamlSubmission,
      participantId: String,
      entryId: DamlLogEntryId,
      logEntry: DamlLogEntry,
      outputState: Map[DamlStateKey, DamlStateValue]): Unit =
    optLedgerDumpStream.foreach { outs =>
      val dumpEntry = DamlKvutils.LedgerDumpEntry.newBuilder
        .setSubmission(Envelope.enclose(submission))
        .setEntryId(entryId)
        .setParticipantId(participantId)
        .setLogEntry(Envelope.enclose(logEntry))
        .addAllOutputState(
          outputState.map {
            case (k, v) =>
              DamlKvutils.LedgerDumpEntry.StatePair.newBuilder
                .setStateKey(k)
                .setStateValue(Envelope.enclose(v))
                .build
          }.asJava
        )
        .build

      // Messages are delimited by a header containing the message size as int32
      outs.writeInt(dumpEntry.getSerializedSize)
      dumpEntry.writeTo(outs)
      outs.flush()
    }

} 
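The dump is written as size-delimited records: an int32 length header followed by the serialized message. Reading such a file back can be sketched as follows; the protobuf payload is left as raw bytes rather than parsed.

import java.io.{DataInputStream, EOFException, FileInputStream}

object SizeDelimitedReaderSketch {

  // Read back <int32 length><payload> records until end of file.
  def readAll(path: String): Seq[Array[Byte]] = {
    val in = new DataInputStream(new FileInputStream(path))
    val records = Seq.newBuilder[Array[Byte]]
    try {
      while (true) {
        val size = in.readInt()
        val bytes = new Array[Byte](size)
        in.readFully(bytes)
        records += bytes
      }
    } catch {
      case _: EOFException => // end of dump reached
    } finally {
      in.close()
    }
    records.result()
  }
}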
Example 46
Source File: HttpClientTestSupport.scala    From wix-http-testkit   with MIT License 5 votes vote down vote up
package com.wix.e2e.http.drivers

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}

import akka.http.scaladsl.model.HttpMethods.GET
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.`Transfer-Encoding`
import akka.stream.scaladsl.Source
import com.wix.e2e.http.client.extractors._
import com.wix.e2e.http.info.HttpTestkitVersion
import com.wix.e2e.http.matchers.{RequestMatcher, ResponseMatcher}
import com.wix.e2e.http.{BaseUri, HttpRequest, RequestHandler}
import com.wix.test.random._

import scala.collection.immutable
import scala.collection.mutable.ListBuffer

trait HttpClientTestSupport {
  val parameter = randomStrPair
  val header = randomStrPair
  val formData = randomStrPair
  val userAgent = randomStr
  val cookie = randomStrPair
  val path = s"$randomStr/$randomStr"
  val anotherPath = s"$randomStr/$randomStr"
  val someObject = SomeCaseClass(randomStr, randomInt)

  val somePort = randomPort
  val content = randomStr
  val anotherContent = randomStr

  val requestData = ListBuffer.empty[String]


  val bigResponse = 1024 * 1024

  def issueChunkedPostRequestWith(content: String, toPath: String)(implicit baseUri: BaseUri) = {
    val serverUrl = new URL(s"http://localhost:${baseUri.port}/$toPath")
    val conn = serverUrl.openConnection.asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("POST")
    conn.setRequestProperty("Content-Type", "text/plain")
    conn.setChunkedStreamingMode(0)
    conn.setDoOutput(true)
    conn.setDoInput(true)
    conn.setUseCaches(false)
    conn.connect()

    val out = new DataOutputStream(conn.getOutputStream)
    out.writeBytes(content)
    out.flush()
    out.close()
    conn.disconnect()
  }
}

object HttpClientTestResponseHandlers {
  def handlerFor(path: String, returnsBody: String): RequestHandler = {
    case r: HttpRequest if r.uri.path.toString.endsWith(path) => HttpResponse(entity = returnsBody)
  }

  def unmarshallingAndStoringHandlerFor(path: String, storeTo: ListBuffer[String]): RequestHandler = {
    case r: HttpRequest if r.uri.path.toString.endsWith(path) =>
      storeTo.append( r.extractAsString )
      HttpResponse()
  }

  def bigResponseWith(size: Int): RequestHandler = {
    case HttpRequest(GET, uri, _, _, _) if uri.path.toString().contains("big-response") =>
      HttpResponse(entity = HttpEntity(randomStrWith(size)))
  }

  def chunkedResponseFor(path: String): RequestHandler = {
    case r: HttpRequest if r.uri.path.toString.endsWith(path) =>
      HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain(UTF-8)`, Source.single(randomStr)))
  }

  def alwaysRespondWith(transferEncoding: TransferEncoding, toPath: String): RequestHandler = {
    case r: HttpRequest if r.uri.path.toString.endsWith(toPath) =>
      HttpResponse().withHeaders(immutable.Seq(`Transfer-Encoding`(transferEncoding)))
  }

  val slowRespondingServer: RequestHandler = { case _ => Thread.sleep(500); HttpResponse() }
}

case class SomeCaseClass(s: String, i: Int)

object HttpClientMatchers {
  import com.wix.e2e.http.matchers.RequestMatchers._

  def haveClientHttpTestkitUserAgentWithLibraryVersion: RequestMatcher =
    haveAnyHeadersOf("User-Agent" -> s"client-http-testkit/$HttpTestkitVersion")
}

object HttpServerMatchers {
  import com.wix.e2e.http.matchers.ResponseMatchers._

  def haveServerHttpTestkitHeaderWithLibraryVersion: ResponseMatcher =
    haveAnyHeadersOf("Server" -> s"server-http-testkit/$HttpTestkitVersion")
} 
Example 47
Source File: OapBitmapWrappedFiberCacheSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.utils

import java.io.{ByteArrayOutputStream, DataOutputStream, FileOutputStream}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.roaringbitmap.RoaringBitmap

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.datasources.OapException
import org.apache.spark.sql.execution.datasources.oap.filecache.{BitmapFiberId, FiberCache}
import org.apache.spark.sql.oap.OapRuntime
import org.apache.spark.sql.test.oap.SharedOapContext
import org.apache.spark.util.Utils

// Below are used to test the functionality of OapBitmapWrappedFiberCache class.
class OapBitmapWrappedFiberCacheSuite
  extends QueryTest with SharedOapContext {

  private def loadRbFile(fin: FSDataInputStream, offset: Long, size: Int): FiberCache =
    OapRuntime.getOrCreate.fiberCacheManager.toIndexFiberCache(fin, offset, size)

  test("test the functionality of OapBitmapWrappedFiberCache class") {
    val CHUNK_SIZE = 1 << 16
    val dataForRunChunk = (1 to 9).toSeq
    val dataForArrayChunk = Seq(1, 3, 5, 7, 9)
    val dataForBitmapChunk = (1 to 10000).filter(_ % 2 == 1)
    val dataCombination =
      dataForBitmapChunk ++ dataForArrayChunk ++ dataForRunChunk
    val dataArray =
      Array(dataForRunChunk, dataForArrayChunk, dataForBitmapChunk, dataCombination)
    dataArray.foreach(dataIdx => {
      val dir = Utils.createTempDir()
      val rb = new RoaringBitmap()
      dataIdx.foreach(rb.add)
      val rbFile = dir.getAbsolutePath + "rb.bin"
      rb.runOptimize()
      val rbFos = new FileOutputStream(rbFile)
      val rbBos = new ByteArrayOutputStream()
      val rbDos = new DataOutputStream(rbBos)
      rb.serialize(rbDos)
      rbBos.writeTo(rbFos)
      rbBos.close()
      rbDos.close()
      rbFos.close()
      val rbPath = new Path(rbFile.toString)
      val conf = new Configuration()
      val fin = rbPath.getFileSystem(conf).open(rbPath)
      val rbFileSize = rbPath.getFileSystem(conf).getFileStatus(rbPath).getLen
      val rbFiber = BitmapFiberId(
        () => loadRbFile(fin, 0L, rbFileSize.toInt), rbPath.toString, 0, 0)
      val rbWfc = new OapBitmapWrappedFiberCache(
        OapRuntime.getOrCreate.fiberCacheManager.get(rbFiber))
      rbWfc.init
      val chunkLength = rbWfc.getTotalChunkLength
      val length = dataIdx.size / CHUNK_SIZE
      assert(chunkLength == (length + 1))
      val chunkKeys = rbWfc.getChunkKeys
      assert(chunkKeys(0).toInt == 0)
      rbWfc.setOffset(0)
      val chunk = rbWfc.getIteratorForChunk(0)
      chunk match {
        case RunChunkIterator(rbWfc) => assert(chunk == RunChunkIterator(rbWfc))
        case ArrayChunkIterator(rbWfc, 0) => assert(chunk == ArrayChunkIterator(rbWfc, 0))
        case BitmapChunkIterator(rbWfc) => assert(chunk == BitmapChunkIterator(rbWfc))
        case _ => throw new OapException("unexpected chunk in OapBitmapWrappedFiberCache.")
      }
      rbWfc.release
      fin.close
      dir.delete
    })
  }
} 
Example 48
Source File: HadoopConfig.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.examples.fsio

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import scala.language.implicitConversions

import org.apache.hadoop.conf.Configuration

import org.apache.gearpump.cluster.UserConfig
import org.apache.gearpump.util.Constants._

class HadoopConfig(config: UserConfig) {

  def withHadoopConf(conf: Configuration): UserConfig = {
    config.withBytes(HADOOP_CONF, serializeHadoopConf(conf))
  }

  def hadoopConf: Configuration = deserializeHadoopConf(config.getBytes(HADOOP_CONF).get)

  private def serializeHadoopConf(conf: Configuration): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val dataOut = new DataOutputStream(out)
    conf.write(dataOut)
    dataOut.close()
    out.toByteArray
  }

  private def deserializeHadoopConf(bytes: Array[Byte]): Configuration = {
    val in = new ByteArrayInputStream(bytes)
    val dataIn = new DataInputStream(in)
    val result = new Configuration()
    result.readFields(dataIn)
    dataIn.close()
    result
  }
}

object HadoopConfig {
  def empty: HadoopConfig = new HadoopConfig(UserConfig.empty)
  def apply(config: UserConfig): HadoopConfig = new HadoopConfig(config)

  implicit def userConfigToHadoopConfig(userConf: UserConfig): HadoopConfig = {
    HadoopConfig(userConf)
  }
} 
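Configuration itself satisfies Hadoop's Writable contract, which is what serializeHadoopConf and deserializeHadoopConf above rely on when calling write and readFields. A quick round-trip check; the key and value are made up for illustration.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import org.apache.hadoop.conf.Configuration

object HadoopConfRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    conf.set("example.key", "example.value") // hypothetical setting, for illustration only

    val out = new ByteArrayOutputStream()
    val dataOut = new DataOutputStream(out)
    conf.write(dataOut)                      // Configuration -> bytes
    dataOut.close()

    val restored = new Configuration()
    restored.readFields(new DataInputStream(new ByteArrayInputStream(out.toByteArray)))
    println(restored.get("example.key"))     // example.value
  }
}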
Example 49
Source File: CRFFromParsedFile.scala    From CRF-Spark   with Apache License 2.0 5 votes vote down vote up
package com.intel.ssg.bdt.nlp

import java.io.{DataOutputStream, DataInputStream, FileInputStream, FileOutputStream}

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object CRFFromParsedFile {

  def main(args: Array[String]) {
    val templateFile = "src/test/resources/chunking/template"
    val trainFile = "src/test/resources/chunking/serialized/train.data"
    val testFile = "src/test/resources/chunking/serialized/test.data"

    val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}")
    val sc = new SparkContext(conf)

    val templates: Array[String] = scala.io.Source.fromFile(templateFile).getLines().filter(_.nonEmpty).toArray
    val trainRDD: RDD[Sequence] = sc.textFile(trainFile).filter(_.nonEmpty).map(Sequence.deSerializer)

    val model: CRFModel = CRF.train(templates, trainRDD, 0.25, 1, 100, 1E-3, "L1")

    val testRDD: RDD[Sequence] = sc.textFile(testFile).filter(_.nonEmpty).map(Sequence.deSerializer)

    
    val results: RDD[Sequence] = model.setNBest(10)
      .setVerboseMode(VerboseLevel1)
      .predict(testRDD)

    val score = results
      .zipWithIndex()
      .map(_.swap)
      .join(testRDD.zipWithIndex().map(_.swap))
      .map(_._2)
      .map(x => x._1.compare(x._2))
      .reduce(_ + _)
    val total = testRDD.map(_.toArray.length).reduce(_ + _)
    println(s"Prediction Accuracy: $score / $total")

    sc.stop()
  }
} 
Example 50
Source File: PythonGatewayServer.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket

import py4j.GatewayServer

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


private[spark] object PythonGatewayServer extends Logging {
  initializeLogIfNecessary(true)

  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // Start a GatewayServer on an ephemeral port
    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the bound port back to the caller via the caller-specified callback port
    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
    val callbackSocket = new Socket(callbackHost, callbackPort)
    val dos = new DataOutputStream(callbackSocket.getOutputStream)
    dos.writeInt(boundPort)
    dos.close()
    callbackSocket.close()

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
} 
Example 51
Source File: PythonRDDSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkFunSuite

class PythonRDDSuite extends SparkFunSuite {

  test("Writing large strings to the worker") {
    val input: List[String] = List("a"*100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }

  test("Handle nulls gracefully") {
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    // Should not have NPE when write an Iterator with null in it
    // The correctness will be tested in Python
    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
    PythonRDD.writeIteratorToStream(Iterator("a".getBytes(StandardCharsets.UTF_8), null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes(StandardCharsets.UTF_8)), buffer)
    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
    PythonRDD.writeIteratorToStream(Iterator(
      (null, null),
      ("a".getBytes(StandardCharsets.UTF_8), null),
      (null, "b".getBytes(StandardCharsets.UTF_8))), buffer)
  }
} 
Example 52
Source File: MasterWebUISuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master.ui

import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Date

import scala.collection.mutable.HashMap

import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver}
import org.apache.spark.deploy.DeployTestUtils._
import org.apache.spark.deploy.master._
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}


class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {

  val conf = new SparkConf
  val securityMgr = new SecurityManager(conf)
  val rpcEnv = mock(classOf[RpcEnv])
  val master = mock(classOf[Master])
  val masterEndpointRef = mock(classOf[RpcEndpointRef])
  when(master.securityMgr).thenReturn(securityMgr)
  when(master.conf).thenReturn(conf)
  when(master.rpcEnv).thenReturn(rpcEnv)
  when(master.self).thenReturn(masterEndpointRef)
  val masterWebUI = new MasterWebUI(master, 0)

  override def beforeAll() {
    super.beforeAll()
    masterWebUI.bind()
  }

  override def afterAll() {
    masterWebUI.stop()
    super.afterAll()
  }

  test("kill application") {
    val appDesc = createAppDesc()
    // use new start date so it isn't filtered by UI
    val activeApp = new ApplicationInfo(
      new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue)

    when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp)))

    val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/"
    val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify the master was called to remove the active app
    verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED)
  }

  test("kill driver") {
    val activeDriverId = "driver-0"
    val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/"
    val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true")))
    val conn = sendHttpRequest(url, "POST", body)
    conn.getResponseCode

    // Verify that master was asked to kill driver with the correct id
    verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId))
  }

  private def convPostDataToString(data: Map[String, String]): String = {
    (for ((name, value) <- data) yield s"$name=$value").mkString("&")
  }

  
  // Send an HTTP request with the given method and optional form-encoded body, returning the open connection.
  private def sendHttpRequest(
      url: String,
      method: String,
      body: String = ""): HttpURLConnection = {
    val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    if (body.nonEmpty) {
      conn.setDoOutput(true)
      conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded")
      conn.setRequestProperty("Content-Length", Integer.toString(body.length))
      val out = new DataOutputStream(conn.getOutputStream)
      out.write(body.getBytes(StandardCharsets.UTF_8))
      out.close()
    }
    conn
  }
} 
Example 53
Source File: PythonGatewayServer.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket

import py4j.GatewayServer

import org.apache.spark.Logging
import org.apache.spark.util.Utils


private[spark] object PythonGatewayServer extends Logging {
  def main(args: Array[String]): Unit = Utils.tryOrExit {
    // Start a GatewayServer on an ephemeral port
    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError("GatewayServer failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the bound port back to the caller via the caller-specified callback port
    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
    val callbackSocket = new Socket(callbackHost, callbackPort)
    val dos = new DataOutputStream(callbackSocket.getOutputStream)
    dos.writeInt(boundPort)
    dos.close()
    callbackSocket.close()

    // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
    while (System.in.read() != -1) {
      // Do nothing
    }
    logDebug("Exiting due to broken pipe from Python driver")
    System.exit(0)
  }
} 
Example 54
Source File: PythonRDDSuite.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}

import org.scalatest.FunSuite

class PythonRDDSuite extends FunSuite {

  test("Writing large strings to the worker") {
    val input: List[String] = List("a"*100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }

  test("Handle nulls gracefully") {
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    // Should not have NPE when write an Iterator with null in it
    // The correctness will be tested in Python
    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
    PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer)
    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer)
    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
    PythonRDD.writeIteratorToStream(
      Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer)
  }
} 
Example 55
Source File: CarbonCustomBlockDistributionTest.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.spark.testsuite.blockprune

import java.io.DataOutputStream

import org.apache.spark.sql.Row
import org.scalatest.BeforeAndAfterAll
import org.apache.carbondata.core.constants.CarbonCommonConstants
import org.apache.carbondata.core.datastore.impl.FileFactory
import org.apache.carbondata.core.util.CarbonProperties
import org.apache.spark.sql.test.util.QueryTest


class CarbonCustomBlockDistributionTest extends QueryTest with BeforeAndAfterAll {
  val outputPath = s"$resourcesPath/block_prune_test.csv"
  override def beforeAll {
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_CUSTOM_BLOCK_DISTRIBUTION, "true")
    // Since the data needed for block prune is big, need to create a temp data file
    val testData: Array[String] = new Array[String](3)
    testData(0) = "a"
    testData(1) = "b"
    testData(2) = "c"
    var writer: DataOutputStream = null
    try {
      val file = FileFactory.getCarbonFile(outputPath)
      if (!file.exists()) {
        file.createNewFile()
      }
      writer = FileFactory.getDataOutputStream(outputPath)
      for (i <- 0 to 2) {
        for (j <- 0 to 240000) {
          writer.writeBytes(testData(i) + "," + j + "\n")
        }
      }
    } catch {
      case ex: Exception =>
        LOGGER.error("Build test file for block prune failed", ex)
    } finally {
      if (writer != null) {
        try {
          writer.close()
        } catch {
          case ex: Exception =>
            LOGGER.error("Close output stream catching exception", ex)
        }
      }
    }

    sql("DROP TABLE IF EXISTS blockprune")
  }

  test("test block prune query") {
    sql(
      """
        CREATE TABLE IF NOT EXISTS blockprune (name string, id int)
        STORED AS carbondata
      """)
    sql(
        s"LOAD DATA LOCAL INPATH '$outputPath' INTO table blockprune options('FILEHEADER'='name,id')"
      )
    // data is in all 7 blocks
    checkAnswer(
      sql(
        """
          select name,count(name) as amount from blockprune
          where name='c' or name='b' or name='a' group by name
        """),
      Seq(Row("a", 240001), Row("b", 240001), Row("c", 240001)))

    // data only in middle 3/4/5 blocks
    checkAnswer(
      sql(
        """
          select name,count(name) as amount from blockprune
          where name='b' group by name
        """),
      Seq(Row("b", 240001)))
  }

  override def afterAll {
    CarbonProperties.getInstance()
      .addProperty(CarbonCommonConstants.CARBON_CUSTOM_BLOCK_DISTRIBUTION,
        CarbonCommonConstants.CARBON_CUSTOM_BLOCK_DISTRIBUTION_DEFAULT)
    // delete the temp data file
    try {
      val file = FileFactory.getCarbonFile(outputPath)
      if (file.exists()) {
        file.delete()
      }
    } catch {
      case ex: Exception =>
        LOGGER.error("Delete temp test data file for block prune catching exception", ex)
    }
    sql("DROP TABLE IF EXISTS blockprune")
  }

} 
Example 56
Source File: AmqpFieldValueSpec.scala    From fs2-rabbit   with Apache License 2.0 5 votes vote down vote up
package dev.profunktor.fs2rabbit

import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream}
import java.time.Instant

import com.rabbitmq.client.impl.{ValueReader, ValueWriter}
import dev.profunktor.fs2rabbit.model.AmqpFieldValue._
import dev.profunktor.fs2rabbit.model.{AmqpFieldValue, ShortString}
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.Assertion
import org.scalatest.matchers.should.Matchers

class AmqpFieldValueSpec extends AnyFlatSpecLike with Matchers with AmqpPropertiesArbitraries {

  it should "convert from and to Java primitive header values" in {
    val intVal    = IntVal(1)
    val longVal   = LongVal(2L)
    val stringVal = StringVal("hey")
    val arrayVal  = ArrayVal(Vector(IntVal(3), IntVal(2), IntVal(1)))

    AmqpFieldValue.unsafeFrom(intVal.toValueWriterCompatibleJava) should be(intVal)
    AmqpFieldValue.unsafeFrom(longVal.toValueWriterCompatibleJava) should be(longVal)
    AmqpFieldValue.unsafeFrom(stringVal.toValueWriterCompatibleJava) should be(stringVal)
    AmqpFieldValue.unsafeFrom("fs2") should be(StringVal("fs2"))
    AmqpFieldValue.unsafeFrom(arrayVal.toValueWriterCompatibleJava) should be(arrayVal)
  }
  it should "preserve the same value after a round-trip through impure and from" in {
    forAll { amqpHeaderVal: AmqpFieldValue =>
      AmqpFieldValue.unsafeFrom(amqpHeaderVal.toValueWriterCompatibleJava) == amqpHeaderVal
    }
  }

  it should "preserve the same values after a round-trip through the Java ValueReader and ValueWriter" in {
    forAll(assertThatValueIsPreservedThroughJavaWriteAndRead _)
  }

  it should "preserve a specific StringVal that previously failed after a round-trip through the Java ValueReader and ValueWriter" in {
    assertThatValueIsPreservedThroughJavaWriteAndRead(StringVal("kyvmqzlbjivLqQFukljghxdowkcmjklgSeybdy"))
  }

  it should "preserve a specific DateVal created from an Instant that has millisecond accuracy after a round-trip through the Java ValueReader and ValueWriter" in {
    val instant   = Instant.parse("4000-11-03T20:17:29.57Z")
    val myDateVal = TimestampVal.from(instant)
    assertThatValueIsPreservedThroughJavaWriteAndRead(myDateVal)
  }

  "DecimalVal" should "reject a BigDecimal of an unscaled value with 33 bits..." in {
    DecimalVal.from(BigDecimal(Int.MaxValue) + BigDecimal(1)) should be(None)
  }
  it should "reject a BigDecimal with a scale over octet size" in {
    DecimalVal.from(new java.math.BigDecimal(java.math.BigInteger.valueOf(12345L), 1000)) should be(None)
  }

  // We need to wrap things in a dummy table because the method that would be
  // great to test with ValueReader, readFieldValue, is private, and so we
  // have to call the next best thing, readTable.
  private def wrapInDummyTable(value: AmqpFieldValue): TableVal =
    TableVal(Map(ShortString.unsafeFrom("dummyKey") -> value))

  private def createWriterFromQueue(outputResults: collection.mutable.Queue[Byte]): ValueWriter =
    new ValueWriter({
      new DataOutputStream({
        new OutputStream {
          override def write(b: Int): Unit =
            outputResults.enqueue(b.toByte)
        }
      })
    })

  private def createReaderFromQueue(input: collection.mutable.Queue[Byte]): ValueReader = {
    val inputStream = new InputStream {
      override def read(): Int =
        try {
          val result = input.dequeue()
          // A signed -> unsigned conversion because bytes by default are
          // converted into signed ints, which is bad when the API of read
          // states that negative numbers indicate EOF...
          0Xff & result.toInt
        } catch {
          case _: NoSuchElementException => -1
        }

      override def available(): Int = {
        val result = input.size
        result
      }
    }
    new ValueReader(new DataInputStream(inputStream))
  }

  private def assertThatValueIsPreservedThroughJavaWriteAndRead(amqpHeaderVal: AmqpFieldValue): Assertion = {
    val outputResultsAsTable = collection.mutable.Queue.empty[Byte]
    val tableWriter          = createWriterFromQueue(outputResultsAsTable)
    tableWriter.writeTable(wrapInDummyTable(amqpHeaderVal).toValueWriterCompatibleJava)

    val reader    = createReaderFromQueue(outputResultsAsTable)
    val readValue = reader.readTable()
    AmqpFieldValue.unsafeFrom(readValue) should be(wrapInDummyTable(amqpHeaderVal))
  }
} 
Example 57
Source File: StreamInputOutputBenchmark.scala    From scala-commons   with MIT License 5 votes vote down vote up
package com.avsystem.commons
package ser

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import com.avsystem.commons.serialization.{GenCodec, StreamInput, StreamOutput}
import org.openjdk.jmh.annotations.{Benchmark, BenchmarkMode, Fork, Measurement, Mode, Scope, State, Warmup}
import org.openjdk.jmh.infra.Blackhole


case class Toplevel(int: Int, nested: Nested, str: String)
case class Nested(list: List[Int], int: Int)

object Toplevel {
  implicit val nestedCodec: GenCodec[Nested] = GenCodec.materialize[Nested]
  implicit val codec: GenCodec[Toplevel] = GenCodec.materialize[Toplevel]
}

@Warmup(iterations = 10)
@Measurement(iterations = 20)
@Fork(1)
@BenchmarkMode(Array(Mode.Throughput))
@State(Scope.Thread)
class StreamInputOutputBenchmark {

  val something = Toplevel(35, Nested(List(121, 122, 123, 124, 125, 126), 53), "lol")

  val inputArray: Array[Byte] = {
    val os = new ByteArrayOutputStream()

    GenCodec.write(new StreamOutput(new DataOutputStream(os)), something)
    os.toByteArray
  }

  @Benchmark
  def testEncode(bh: Blackhole): Unit = {
    val os = new ByteArrayOutputStream(inputArray.length)
    val output = new StreamOutput(new DataOutputStream(os))
    GenCodec.write(output, something)
    bh.consume(os.toByteArray)
  }

  @Benchmark
  def testDecode(bh: Blackhole): Unit = {
    val is = new DataInputStream(new ByteArrayInputStream(inputArray))
    val input = new StreamInput(is)
    bh.consume(GenCodec.read[Toplevel](input))
  }

  @Benchmark
  def testEncodeRaw(bh: Blackhole): Unit = {
    val os = new ByteArrayOutputStream(inputArray.length)
    val output = new StreamOutput(new DataOutputStream(os))
    val toplevelOutput = output.writeObject()
    toplevelOutput.writeField("int").writeSimple().writeInt(35)
    val nestedOutput = toplevelOutput.writeField("nested").writeObject()
    val listOutput = nestedOutput.writeField("list").writeList()
    listOutput.writeElement().writeSimple().writeInt(121)
    listOutput.writeElement().writeSimple().writeInt(122)
    listOutput.writeElement().writeSimple().writeInt(123)
    listOutput.writeElement().writeSimple().writeInt(124)
    listOutput.writeElement().writeSimple().writeInt(125)
    listOutput.writeElement().writeSimple().writeInt(126)
    listOutput.finish()
    nestedOutput.writeField("int").writeSimple().writeInt(53)
    nestedOutput.finish()
    toplevelOutput.writeField("str").writeSimple().writeString("lol")
    toplevelOutput.finish()
    bh.consume(os.toByteArray)
  }

  @Benchmark
  def testDecodeRaw(bh: Blackhole): Unit = {
    val is = new DataInputStream(new ByteArrayInputStream(inputArray))
    val input = new StreamInput(is)
    val objInput = input.readObject()
    val intField = objInput.nextField().readSimple().readInt()
    val nestedInput = objInput.nextField().readObject()
    val listInput = nestedInput.nextField().readList()
    val listNested = List(
      listInput.nextElement().readSimple().readInt(),
      listInput.nextElement().readSimple().readInt(),
      listInput.nextElement().readSimple().readInt(),
      listInput.nextElement().readSimple().readInt(),
      listInput.nextElement().readSimple().readInt(),
      listInput.nextElement().readSimple().readInt()
    )
    listInput.hasNext
    val intNested = nestedInput.nextField().readSimple().readInt()
    nestedInput.hasNext
    val strField = objInput.nextField().readSimple().readString()
    objInput.hasNext
    bh.consume(Toplevel(intField, Nested(listNested, intNested), strField))
  }
} 
Example 58
Source File: StreamGenCodecTest.scala    From scala-commons   with MIT License 5 votes vote down vote up
package com.avsystem.commons
package serialization

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

class StreamGenCodecTest extends GenCodecRoundtripTest {
  type Raw = Array[Byte]

  def writeToOutput(write: Output => Unit): Array[Byte] = {
    val baos = new ByteArrayOutputStream
    write(new StreamOutput(new DataOutputStream(baos)))
    baos.toByteArray
  }

  def createInput(raw: Array[Byte]): Input =
    new StreamInput(new DataInputStream(new ByteArrayInputStream(raw)))
}