scala.language.existentials Scala Examples

The following examples show how to use scala.language.existentials. Each one is taken from an open-source project; the header above each example names the source file, the project it comes from, and its license.
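Before the project examples, here is a minimal, self-contained sketch written for this page (it is not taken from any of the projects below). It shows the kind of type the import enables: an existential in which the quantified type appears more than once cannot be written with a plain wildcard, so the explicit forSome syntax is needed, and without importing scala.language.existentials (or passing -language:existentials to the compiler) such a type produces a feature warning.

import scala.language.existentials

object ExistentialsDemo extends App {
  // A pair whose second element has the same (unknown) type as the Class in its
  // first element. Because T occurs twice, this cannot be expressed with a
  // wildcard, hence the explicit forSome form guarded by the import above.
  type Tagged = (Class[T], T) forSome { type T }

  val values: List[Tagged] = List(
    (classOf[String], "hello"),
    (classOf[Integer], Integer.valueOf(42)))

  values.foreach { case (cls, v) => println(s"${cls.getSimpleName}: $v") }
}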
Example 1
Source File: BlockedRDD.scala    From hail   with MIT License
package is.hail.sparkextras

import is.hail.utils._
import org.apache.spark.rdd.RDD
import org.apache.spark.{Dependency, NarrowDependency, Partition, TaskContext}

import scala.language.existentials
import scala.reflect.ClassTag

case class BlockedRDDPartition(@transient rdd: RDD[_],
  index: Int,
  first: Int,
  last: Int) extends Partition {
  require(first <= last)

  val parentPartitions: Array[Partition] = range.map(rdd.partitions).toArray

  def range: Range = first to last
}

class BlockedRDD[T](@transient var prev: RDD[T],
  @transient val partFirst: Array[Int],
  @transient val partLast: Array[Int]
)(implicit tct: ClassTag[T]) extends RDD[T](prev.sparkContext, Nil) {
  assert(partFirst.length == partLast.length)

  override def getPartitions: Array[Partition] = {
    Array.tabulate[Partition](partFirst.length)(i =>
      BlockedRDDPartition(prev, i, partFirst(i), partLast(i)))
  }

  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val parent = dependencies.head.rdd.asInstanceOf[RDD[T]]
    split.asInstanceOf[BlockedRDDPartition].parentPartitions.iterator.flatMap(p =>
      parent.iterator(p, context))
  }

  override def getDependencies: Seq[Dependency[_]] = {
    FastSeq(new NarrowDependency(prev) {
      def getParents(id: Int): Seq[Int] =
        partitions(id).asInstanceOf[BlockedRDDPartition].range
    })
  }

  override def clearDependencies() {
    super.clearDependencies()
    prev = null
  }

  override def getPreferredLocations(partition: Partition): Seq[String] = {
    val prevPartitions = prev.partitions
    val range = partition.asInstanceOf[BlockedRDDPartition].range

    val locationAvail = range.flatMap(i =>
      prev.preferredLocations(prevPartitions(i)))
      .groupBy(identity)
      .mapValues(_.length)

    if (locationAvail.isEmpty)
      return FastSeq.empty[String]

    val m = locationAvail.values.max
    locationAvail.filter(_._2 == m)
      .keys
      .toFastSeq
  }
} 
Example 2
Source File: FPTreeSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.mllib.fpm

import scala.language.existentials

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("add transaction") {//增加转换
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    assert(tree.root.children.size == 2)
    assert(tree.root.children.contains("a"))
    assert(tree.root.children("a").item.equals("a"))
    assert(tree.root.children("a").count == 2)
    assert(tree.root.children.contains("b"))
    assert(tree.root.children("b").item.equals("b"))
    assert(tree.root.children("b").count == 1)
    var child = tree.root.children("a")
    assert(child.children.size == 1)
    assert(child.children.contains("b"))
    assert(child.children("b").item.equals("b"))
    assert(child.children("b").count == 2)
    child = child.children("b")
    assert(child.children.size == 2)
    assert(child.children.contains("c"))
    assert(child.children.contains("y"))
    assert(child.children("c").item.equals("c"))
    assert(child.children("y").item.equals("y"))
    assert(child.children("c").count == 1)
    assert(child.children("y").count == 1)
  }

  test("merge tree") {//合并树
    val tree1 = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    val tree2 = new FPTree[String]
      .add(Seq("a", "b"))
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "c", "d"))
      .add(Seq("a", "x"))
      .add(Seq("a", "x", "y"))
      .add(Seq("c", "n"))
      .add(Seq("c", "m"))

    val tree3 = tree1.merge(tree2)

    assert(tree3.root.children.size == 3)
    assert(tree3.root.children("a").count == 7)
    assert(tree3.root.children("b").count == 1)
    assert(tree3.root.children("c").count == 2)
    val child1 = tree3.root.children("a")
    assert(child1.children.size == 2)
    assert(child1.children("b").count == 5)
    assert(child1.children("x").count == 2)
    val child2 = child1.children("b")
    assert(child2.children.size == 2)
    assert(child2.children("y").count == 1)
    assert(child2.children("c").count == 3)
    val child3 = child2.children("c")
    assert(child3.children.size == 1)
    assert(child3.children("d").count == 1)
    val child4 = child1.children("x")
    assert(child4.children.size == 1)
    assert(child4.children("y").count == 1)
    val child5 = tree3.root.children("c")
    assert(child5.children.size == 2)
    assert(child5.children("n").count == 1)
    assert(child5.children("m").count == 1)
  }

  test("extract freq itemsets") {//频繁项集的提取物
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("a", "b"))
      .add(Seq("a"))
      .add(Seq("b"))
      .add(Seq("b", "n"))

    val freqItemsets = tree.extract(3L).map { case (items, count) =>
      (items.toSet, count)
    }.toSet
    val expected = Set(
      (Set("a"), 4L),
      (Set("b"), 5L),
      (Set("a", "b"), 3L))
    assert(freqItemsets === expected)
  }
} 
Example 3
Source File: OutputDataStream.scala    From affinity   with Apache License 2.0
package io.amient.affinity.core.util

import akka.util.Timeout
import io.amient.affinity.core.actor.TransactionCoordinator
import io.amient.affinity.core.serde.AbstractSerde
import io.amient.affinity.core.storage.{LogStorage, LogStorageConf, Record}

import scala.concurrent.duration._
import scala.concurrent.{ExecutionContext, Future}
import scala.language.{existentials, postfixOps}


object OutputDataStream {

  class TransactionCoordinatorNoop extends TransactionCoordinator {
    override def _begin(): Future[Unit] = Future.successful(())

    override def _commit(): Future[Unit] = Future.successful(())

    override def _abort(): Future[Unit] = Future.successful(())

    override def append(topic: String, key: Array[Byte], value: Array[Byte], timestamp: Option[Long], partition: Option[Int]): Future[_ <: Comparable[_]] = {
      Future.successful(0L)
    }
  }

  //create OutputDataStream without transactional support
  def apply[K, V](keySerde: AbstractSerde[_ >: K], valSerde: AbstractSerde[_ >: V], conf: LogStorageConf): OutputDataStream[K, V] = {
    new OutputDataStream[K, V](new TransactionCoordinatorNoop, keySerde, valSerde, conf)
  }

}

class OutputDataStream[K, V] private[affinity](txn: TransactionCoordinator, keySerde: AbstractSerde[_ >: K], valSerde: AbstractSerde[_ >: V], conf: LogStorageConf) {

  lazy val storage = LogStorage.newInstanceEnsureExists(conf)

  lazy private val topic: String = storage.getTopic()

  implicit val timeout = Timeout(1 minute) //FIXME

  def append(record: Record[K, V]): Future[_ <: Comparable[_]] = {
    if (txn.inTransaction()) {
      txn.append(topic, keySerde.toBytes(record.key), valSerde.toBytes(record.value), Option(record.timestamp), None)
    } else {
      val binaryRecord = new Record(keySerde.toBytes(record.key), valSerde.toBytes(record.value), record.timestamp)
      val jf = storage.append(binaryRecord)
      Future(jf.get)(ExecutionContext.Implicits.global)
    }
  }

  def delete(key: K): Future[_ <: Comparable[_]] = {
    if (txn.inTransaction()) {
      txn.append(topic, keySerde.toBytes(key), null, None, None)
    } else {
      val jf = storage.delete(keySerde.toBytes(key))
      Future(jf.get)(ExecutionContext.Implicits.global)
    }
  }

  def flush(): Unit = storage.flush()

  def close(): Unit = {
    try flush() finally try storage.close() finally {
      keySerde.close()
      valSerde.close()
    }
  }
} 
Example 4
Source File: EvalConfig.scala    From aerosolve   with Apache License 2.0
package com.airbnb.common.ml.strategy.config

import scala.language.existentials

import com.typesafe.config.Config

import com.airbnb.common.ml.strategy.data.TrainingData
import com.airbnb.common.ml.util.ScalaLogging


case class EvalConfig(
    trainingConfig: TrainingConfig,
    evalDataQuery: String,
    holdoutDataQuery: String
)

object DirectQueryEvalConfig extends ScalaLogging {

  def loadConfig[T](
      config: Config
  ): EvalConfig = {
    val evalDataQuery = config.getString("eval_data_query")
    val holdoutDataQuery = config.getString("holdout_data_query")

    logger.info(s"Eval Data Query: $evalDataQuery")

    EvalConfig(
      TrainingConfig.loadConfig(config),
      evalDataQuery,
      holdoutDataQuery)
  }
} 
Example 5
Source File: ShuffleMapTask.scala    From iolap   with Apache License 2.0
package org.apache.spark.scheduler

import java.nio.ByteBuffer

import scala.language.existentials

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleWriter


  def this(partitionId: Int) {
    this(0, null, new Partition { override def index: Int = 0 }, null)
  }

  @transient private val preferredLocs: Seq[TaskLocation] = {
    if (locs == null) Nil else locs.toSet.toSeq
  }

  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val deserializeStartTime = System.currentTimeMillis()
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      return writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

  override def preferredLocations: Seq[TaskLocation] = preferredLocs

  override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
} 
Example 6
Source File: FPTreeSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.mllib.fpm

import scala.language.existentials

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("add transaction") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    assert(tree.root.children.size == 2)
    assert(tree.root.children.contains("a"))
    assert(tree.root.children("a").item.equals("a"))
    assert(tree.root.children("a").count == 2)
    assert(tree.root.children.contains("b"))
    assert(tree.root.children("b").item.equals("b"))
    assert(tree.root.children("b").count == 1)
    var child = tree.root.children("a")
    assert(child.children.size == 1)
    assert(child.children.contains("b"))
    assert(child.children("b").item.equals("b"))
    assert(child.children("b").count == 2)
    child = child.children("b")
    assert(child.children.size == 2)
    assert(child.children.contains("c"))
    assert(child.children.contains("y"))
    assert(child.children("c").item.equals("c"))
    assert(child.children("y").item.equals("y"))
    assert(child.children("c").count == 1)
    assert(child.children("y").count == 1)
  }

  test("merge tree") {
    val tree1 = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    val tree2 = new FPTree[String]
      .add(Seq("a", "b"))
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "c", "d"))
      .add(Seq("a", "x"))
      .add(Seq("a", "x", "y"))
      .add(Seq("c", "n"))
      .add(Seq("c", "m"))

    val tree3 = tree1.merge(tree2)

    assert(tree3.root.children.size == 3)
    assert(tree3.root.children("a").count == 7)
    assert(tree3.root.children("b").count == 1)
    assert(tree3.root.children("c").count == 2)
    val child1 = tree3.root.children("a")
    assert(child1.children.size == 2)
    assert(child1.children("b").count == 5)
    assert(child1.children("x").count == 2)
    val child2 = child1.children("b")
    assert(child2.children.size == 2)
    assert(child2.children("y").count == 1)
    assert(child2.children("c").count == 3)
    val child3 = child2.children("c")
    assert(child3.children.size == 1)
    assert(child3.children("d").count == 1)
    val child4 = child1.children("x")
    assert(child4.children.size == 1)
    assert(child4.children("y").count == 1)
    val child5 = tree3.root.children("c")
    assert(child5.children.size == 2)
    assert(child5.children("n").count == 1)
    assert(child5.children("m").count == 1)
  }

  test("extract freq itemsets") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("a", "b"))
      .add(Seq("a"))
      .add(Seq("b"))
      .add(Seq("b", "n"))

    val freqItemsets = tree.extract(3L).map { case (items, count) =>
      (items.toSet, count)
    }.toSet
    val expected = Set(
      (Set("a"), 4L),
      (Set("b"), 5L),
      (Set("a", "b"), 3L))
    assert(freqItemsets === expected)
  }
} 
Example 7
Source File: NettyBlockRpcServer.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.network.netty

import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import scala.language.existentials
import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, StorageLevel}


class NettyBlockRpcServer(
    appId: String,
    serializer: Serializer,
    blockManager: BlockDataManager)
  extends RpcHandler with Logging {

  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      rpcMessage: ByteBuffer,
      responseContext: RpcResponseCallback): Unit = {
    val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
    logTrace(s"Received request: $message")

    message match {
      case openBlocks: OpenBlocks =>
        val blocks: Seq[ManagedBuffer] =
          openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
        logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)

      case uploadBlock: UploadBlock =>
        // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
        val (level: StorageLevel, classTag: ClassTag[_]) = {
          serializer
            .newInstance()
            .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
            .asInstanceOf[(StorageLevel, ClassTag[_])]
        }
        val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
        val blockId = BlockId(uploadBlock.blockId)
        blockManager.putBlockData(blockId, data, level, classTag)
        responseContext.onSuccess(ByteBuffer.allocate(0))
    }
  }

  override def getStreamManager(): StreamManager = streamManager
} 
Example 8
Source File: FPTreeSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.mllib.fpm

import scala.language.existentials

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("add transaction") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    assert(tree.root.children.size == 2)
    assert(tree.root.children.contains("a"))
    assert(tree.root.children("a").item.equals("a"))
    assert(tree.root.children("a").count == 2)
    assert(tree.root.children.contains("b"))
    assert(tree.root.children("b").item.equals("b"))
    assert(tree.root.children("b").count == 1)
    var child = tree.root.children("a")
    assert(child.children.size == 1)
    assert(child.children.contains("b"))
    assert(child.children("b").item.equals("b"))
    assert(child.children("b").count == 2)
    child = child.children("b")
    assert(child.children.size == 2)
    assert(child.children.contains("c"))
    assert(child.children.contains("y"))
    assert(child.children("c").item.equals("c"))
    assert(child.children("y").item.equals("y"))
    assert(child.children("c").count == 1)
    assert(child.children("y").count == 1)
  }

  test("merge tree") {
    val tree1 = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    val tree2 = new FPTree[String]
      .add(Seq("a", "b"))
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "c", "d"))
      .add(Seq("a", "x"))
      .add(Seq("a", "x", "y"))
      .add(Seq("c", "n"))
      .add(Seq("c", "m"))

    val tree3 = tree1.merge(tree2)

    assert(tree3.root.children.size == 3)
    assert(tree3.root.children("a").count == 7)
    assert(tree3.root.children("b").count == 1)
    assert(tree3.root.children("c").count == 2)
    val child1 = tree3.root.children("a")
    assert(child1.children.size == 2)
    assert(child1.children("b").count == 5)
    assert(child1.children("x").count == 2)
    val child2 = child1.children("b")
    assert(child2.children.size == 2)
    assert(child2.children("y").count == 1)
    assert(child2.children("c").count == 3)
    val child3 = child2.children("c")
    assert(child3.children.size == 1)
    assert(child3.children("d").count == 1)
    val child4 = child1.children("x")
    assert(child4.children.size == 1)
    assert(child4.children("y").count == 1)
    val child5 = tree3.root.children("c")
    assert(child5.children.size == 2)
    assert(child5.children("n").count == 1)
    assert(child5.children("m").count == 1)
  }

  test("extract freq itemsets") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("a", "b"))
      .add(Seq("a"))
      .add(Seq("b"))
      .add(Seq("b", "n"))

    val freqItemsets = tree.extract(3L).map { case (items, count) =>
      (items.toSet, count)
    }.toSet
    val expected = Set(
      (Set("a"), 4L),
      (Set("b"), 5L),
      (Set("a", "b"), 3L))
    assert(freqItemsets === expected)
  }
} 
Example 9
Source File: Casts.scala    From hail   with MIT License
package is.hail.expr.ir

import is.hail.asm4s._
import is.hail.types._
import is.hail.types.virtual._

import scala.language.existentials

object Casts {
  private val casts: Map[(Type, Type), (Code[T] => Code[_]) forSome {type T}] = Map(
    (TInt32, TInt32) -> ((x: Code[Int]) => x),
    (TInt32, TInt64) -> ((x: Code[Int]) => x.toL),
    (TInt32, TFloat32) -> ((x: Code[Int]) => x.toF),
    (TInt32, TFloat64) -> ((x: Code[Int]) => x.toD),
    (TInt64, TInt32) -> ((x: Code[Long]) => x.toI),
    (TInt64, TInt64) -> ((x: Code[Long]) => x),
    (TInt64, TFloat32) -> ((x: Code[Long]) => x.toF),
    (TInt64, TFloat64) -> ((x: Code[Long]) => x.toD),
    (TFloat32, TInt32) -> ((x: Code[Float]) => x.toI),
    (TFloat32, TInt64) -> ((x: Code[Float]) => x.toL),
    (TFloat32, TFloat32) -> ((x: Code[Float]) => x),
    (TFloat32, TFloat64) -> ((x: Code[Float]) => x.toD),
    (TFloat64, TInt32) -> ((x: Code[Double]) => x.toI),
    (TFloat64, TInt64) -> ((x: Code[Double]) => x.toL),
    (TFloat64, TFloat32) -> ((x: Code[Double]) => x.toF),
    (TFloat64, TFloat64) -> ((x: Code[Double]) => x),
    (TInt32, TCall) -> ((x: Code[Int]) => x))

  def get(from: Type, to: Type): Code[_] => Code[_] =
    casts(from -> to).asInstanceOf[Code[_] => Code[_]]

  def valid(from: Type, to: Type): Boolean =
    casts.contains(from -> to)
} 
Example 10
Source File: BinarySearch.scala    From hail   with MIT License
package is.hail.expr.ir

import is.hail.annotations.{CodeOrdering, Region}
import is.hail.asm4s._
import is.hail.types.physical._
import is.hail.utils.FastIndexedSeq

import scala.language.existentials

class BinarySearch[C](mb: EmitMethodBuilder[C], typ: PContainer, eltType: PType, keyOnly: Boolean) {

  val elt: PType = typ.elementType
  val ti: TypeInfo[_] = typeToTypeInfo(elt)

  val (compare: CodeOrdering.F[Int], equiv: CodeOrdering.F[Boolean], findElt: EmitMethodBuilder[C], t: PType) = if (keyOnly) {
    val ttype = elt match {
      case t: PBaseStruct =>
        require(t.size == 2)
        t
      case t: PInterval => t.representation.asInstanceOf[PStruct]
    }
    val kt = ttype.types(0)
    val findMB = mb.genEmitMethod("findElt", FastIndexedSeq[ParamType](typeInfo[Long], typeInfo[Boolean], typeToTypeInfo(kt)), typeInfo[Int])
    val mk2l = findMB.newLocal[Boolean]()
    val mk2l1 = mb.newLocal[Boolean]()

    val comp: CodeOrdering.F[Int] = {
      case ((mk1: Code[Boolean], k1: Code[_]), (m2: Code[Boolean], v2: Code[Long] @unchecked)) =>
        Code.memoize(v2, "bs_comp_v2") { v2 =>
          val mk2 = Code(mk2l := m2 || ttype.isFieldMissing(v2, 0), mk2l)
          val k2 = mk2l.mux(defaultValue(kt), Region.loadIRIntermediate(kt)(ttype.fieldOffset(v2, 0)))
          findMB.getCodeOrdering(eltType, kt, CodeOrdering.Compare())((mk1, k1), (mk2, k2))
        }
    }
    val ceq: CodeOrdering.F[Boolean] = {
      case ((mk1: Code[Boolean], k1: Code[_]), (m2: Code[Boolean], v2: Code[Long] @unchecked)) =>
        Code.memoize(v2, "bs_comp_v2") { v2 =>
          val mk2 = Code(mk2l1 := m2 || ttype.isFieldMissing(v2, 0), mk2l1)
          val k2 = mk2l1.mux(defaultValue(kt), Region.loadIRIntermediate(kt)(ttype.fieldOffset(v2, 0)))
          mb.getCodeOrdering(eltType, kt, CodeOrdering.Equiv())((mk1, k1), (mk2, k2))
        }
    }
    (comp, ceq, findMB, kt)
  } else
    (mb.getCodeOrdering(eltType, elt, CodeOrdering.Compare()),
      mb.getCodeOrdering(eltType, elt, CodeOrdering.Equiv()),
      mb.genEmitMethod("findElt", FastIndexedSeq[ParamType](typeInfo[Long], typeInfo[Boolean], elt.ti), typeInfo[Int]), elt)

  private[this] val array = findElt.getCodeParam[Long](1)
  private[this] val m = findElt.getCodeParam[Boolean](2)
  private[this] val e = findElt.getCodeParam(3)(t.ti)
  private[this] val len = findElt.newLocal[Int]()
  private[this] val i = findElt.newLocal[Int]()
  private[this] val low = findElt.newLocal[Int]()
  private[this] val high = findElt.newLocal[Int]()

  def cmp(i: Code[Int]): Code[Int] =
    Code.memoize(i, "binsearch_cmp_i") { i =>
      compare((m, e),
        (typ.isElementMissing(array, i),
          Region.loadIRIntermediate(elt)(typ.elementOffset(array, len, i))))
    }

  // Returns smallest i, 0 <= i < n, for which a(i) >= key, or returns n if a(i) < key for all i
  findElt.emit(Code(
    len := typ.loadLength(array),
    low := 0,
    high := len,
    Code.whileLoop(low < high,
      i := (low + high) / 2,
      (cmp(i) <= 0).mux(
        high := i,
        low := i + 1)),
    low))

  // check missingness of v before calling
  def getClosestIndex(array: Code[Long], m: Code[Boolean], v: Code[_]): Code[Int] = {
    findElt.invokeCode[Int](array, m, v)
  }
} 
Example 11
Source File: PrunedScanSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.sources

import scala.language.existentials

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
// PrunedScan lets the scan request only the required columns; the data source need not return the others
class PrunedScanSource extends RelationProvider { // provides the relation
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
  }
}

case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
  extends BaseRelation
  with PrunedScan {

  override def schema: StructType =
    StructType( // StructType represents a table; each StructField represents a field
      StructField("a", IntegerType, nullable = false) ::
      StructField("b", IntegerType, nullable = false) :: Nil)

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val rowBuilders = requiredColumns.map {
      case "a" => (i: Int) => Seq(i)
      case "b" => (i: Int) => {
        //println(">>>>>>>"+i * 2)
        Seq(i * 2)
      }
    }
    // parallelize: number of partitions
    sqlContext.sparkContext.parallelize(from to to).map(i =>
      Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
  }
}

class PrunedScanSuite extends DataSourceTest with SharedSQLContext {
  protected override lazy val sql = caseInsensitiveContext.sql _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sql(
      """
        |CREATE TEMPORARY TABLE oneToTenPruned
        |USING org.apache.spark.sql.sources.PrunedScanSource
        |OPTIONS (
        |  from '1',
        |  to '10'
        |)
      """.stripMargin)
     
  }

  def testPruning(sqlString: String, expectedColumns: String*): Unit = {
    test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
      val queryExecution = sql(sqlString).queryExecution
      val rawPlan = queryExecution.executedPlan.collect {
        case p: execution.PhysicalRDD => p
      } match {
        case Seq(p) => p
        case _ => fail(s"More than one PhysicalRDD found\n$queryExecution")
      }
      val rawColumns = rawPlan.output.map(_.name)
      val rawOutput = rawPlan.execute().first()

      if (rawColumns != expectedColumns) {
        fail(
          s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" +
          s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" +
            queryExecution)
      }

      if (rawOutput.numFields != expectedColumns.size) {
        fail(s"Wrong output row. Got $rawOutput\n$queryExecution")
      }
    }
  }

} 
Example 12
Source File: SprayUtilities.scala    From mmlspark   with MIT License
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.nbtest

import spray.json.{JsArray, JsObject, JsValue, JsonFormat}

import scala.language.{existentials, implicitConversions}

abstract class SprayOp

case class IndexOp(item: Int) extends SprayOp

case class FieldOp(value: String) extends SprayOp

class SprayUtility(val json: JsValue) {

  private def parseQuery(q: String): List[SprayOp] = {
    q.split("." (0)).flatMap { t =>
      if (t.contains("]") & t.contains("]")) {
        t.split("][".toCharArray).filter(_.length > 0).toSeq match {
          case Seq(index) => Seq(IndexOp(index.toInt))
          case Seq(field, index) => Seq(FieldOp(field), IndexOp(index.toInt))
        }
      } else if (!t.contains("]") & !t.contains("]")) {
        Seq(FieldOp(t)).asInstanceOf[List[SprayOp]]
      } else {
        throw new IllegalArgumentException(s"Cannot parse query: $q")
      }
    }.toList
  }

  private def selectInternal[T](json: JsValue, ops: List[SprayOp])(implicit format: JsonFormat[T]): T = {
    ops match {
      case Nil => json.convertTo[T]
      case IndexOp(i) :: tail =>
        selectInternal[T](json.asInstanceOf[JsArray].elements(i), tail)
      case FieldOp(f) :: tail =>
        selectInternal[T](json.asInstanceOf[JsObject].fields(f), tail)
      case _ => throw new MatchError("This code should be unreachable")
    }
  }

  def select[T](query: String)(implicit format: JsonFormat[T]): T = {
    selectInternal[T](json, parseQuery(query))
  }
}

object SprayImplicits {
  implicit def sprayUtilityConverter(s: JsValue): SprayUtility = new SprayUtility(s)

  implicit def sprayUtilityConversion(s: SprayUtility): JsValue = s.json
} 
Example 13
Source File: NotebookTests.scala    From mmlspark   with MIT License
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.nbtest
//TODO temp hack because ij picks up on it test classes by mistake

import java.util.concurrent.TimeUnit

import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.nbtest.DatabricksUtilities._

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.language.existentials


class NotebookTests extends TestBase {

  test("Databricks Notebooks") {
    val clusterId = createClusterInPool(ClusterName, PoolId)
    try {
      println("Checking if cluster is active")
      tryWithRetries(Seq.fill(60*15)(1000).toArray){() =>
        assert(isClusterActive(clusterId))}
      println("Installing libraries")
      installLibraries(clusterId)
      tryWithRetries(Seq.fill(60*3)(1000).toArray){() =>
        assert(isClusterActive(clusterId))}
      println(s"Creating folder $Folder")
      workspaceMkDir(Folder)
      println(s"Submitting jobs")
      val jobIds = NotebookFiles.map(uploadAndSubmitNotebook(clusterId, _))
      println(s"Submitted ${jobIds.length} for execution: ${jobIds.toList}")
      try {
        val monitors = jobIds.map((runId: Int) => monitorJob(runId, TimeoutInMillis, logLevel = 2))
        println(s"Monitoring Jobs...")
        val failures = monitors
          .map(Await.ready(_, Duration(TimeoutInMillis.toLong, TimeUnit.MILLISECONDS)).value.get)
          .filter(_.isFailure)
        assert(failures.isEmpty)
      } catch {
        case t: Throwable =>
          jobIds.foreach { jid =>
            println(s"Cancelling job $jid")
            cancelRun(jid)
          }
          throw t
      }
    } finally {
      deleteCluster(clusterId)
    }
  }

  ignore("list running jobs for convenievce") {
    val obj = databricksGet("jobs/runs/list?active_only=true&limit=1000")
    println(obj)
  }

} 
Example 14
Source File: ScannerSpec.scala    From better-files   with MIT License
package better.files

import Dsl._

import scala.language.existentials

class ScannerSpec extends CommonSpec {
  def t1 = File.newTemporaryFile()

  "splitter" should "split" in {
    val csvSplitter      = StringSplitter.on(',')
    def split(s: String) = csvSplitter.split(s).toList

    assert(split(",") === List("", ""))
    assert(split("") === List(""))
    assert(split("Hello World") === List("Hello World"))
    assert(split("Hello,World") === List("Hello", "World"))

    assert(split(",,") === List("", "", ""))
    assert(split(",Hello,World,") === List("", "Hello", "World", ""))
    assert(split(",Hello,World") === List("", "Hello", "World"))
    assert(split("Hello,World,") === List("Hello", "World", ""))
  }

  "scanner" should "parse files" in {
    val data = t1 << s"""
    | Hello World
    | 1 2 3
    | Ok 23 football
    """.stripMargin
    data.scanner() foreach { scanner =>
      assert(scanner.lineNumber() == 0)
      assert(scanner.next[String] == "Hello")
      assert(scanner.lineNumber() == 2)
      assert(scanner.next[String] == "World")
      assert(scanner.next[Int] == 1)
      assert(scanner.next[Int] == 2)
      assert(scanner.lineNumber() == 3)
      assert(scanner.next[Int] == 3)
      assert(scanner.nextLine() == " Ok 23 football")
      assert(!scanner.hasNext)
      a[NoSuchElementException] should be thrownBy scanner.next()
      a[NoSuchElementException] should be thrownBy scanner.nextLine()
      assert(!scanner.hasNext)
    }
    data.tokens().toSeq shouldEqual data.newScanner().toSeq
  }

  it should "parse longs/booleans" in {
    val data = for {
      scanner <- Scanner("10 false").autoClosed
    } yield scanner.next[(Long, Boolean)]
    data.get() shouldBe ((10L, false))
  }

  it should "parse custom parsers" in {
    val file = t1 < """
      |Garfield
      |Woofer
    """.stripMargin

    sealed trait Animal
    case class Dog(name: String) extends Animal
    case class Cat(name: String) extends Animal

    implicit val animalParser: Scannable[Animal] = Scannable { scanner =>
      val name = scanner.next[String]
      if (name == "Garfield") Cat(name) else Dog(name)
    }
    file.scanner() foreach { scanner =>
      Seq.fill(2)(scanner.next[Animal]) should contain theSameElementsInOrderAs Seq(Cat("Garfield"), Dog("Woofer"))
    }
  }

  it should "parse empty tokens" in {
    val scanner = Scanner("hello||world", StringSplitter.on('|'))
    List.fill(3)(scanner.next[Option[String]]) shouldEqual List(Some("hello"), None, Some("world"))
  }
} 
Example 15
Source File: LabelsSelectize.scala    From ProductWebUI   with Apache License 2.0
package synereo.client.components


import shared.models.Label
import synereo.client.services.SYNEREOCircuit

import scala.language.existentials
import japgolly.scalajs.react._
import japgolly.scalajs.react.vdom.prefix_<^._
import org.querki.jquery._
import org.scalajs.dom._
import synereo.client.facades.SynereoSelectizeFacade

import scala.language.existentials
import scala.scalajs.js

)(
        <.option(^.value := "")("Select"),
        //          props.proxy().render(searchesRootModel => searchesRootModel.se)
        for (label <- SYNEREOCircuit.zoom(_.searches.searchesModel).value) yield {
          <.option(^.value := label.text, ^.key := label.uid)(s"#${label.text}")
        })

    }
  }

  val component = ReactComponentB[Props]("LabelsSelectize")
    .initialState(State())
    .renderBackend[Backend]
    .componentDidMount(scope => scope.backend.mounted(scope.props))
    .build

  def apply(props: Props) = component(props)
} 
Example 16
Source File: ConnectionsLabelsSelectize.scala    From ProductWebUI   with Apache License 2.0
package synereo.client.components

import japgolly.scalajs.react._
import japgolly.scalajs.react.vdom.prefix_<^._
import org.querki.jquery._
import org.scalajs.dom._
import shared.dtos.Connection
import synereo.client.facades.SynereoSelectizeFacade
import synereo.client.services.SYNEREOCircuit

import scala.language.existentials
import scala.scalajs.js


//scalastyle:off
object ConnectionsLabelsSelectize {
  def getCnxnsAndLabelsFromSelectize(selectizeInputId: String): (Seq[Connection], Seq[String]) = {
    var selectedConnections = Seq[Connection]()
    var selectedLabels = Seq[String]()
    val selector: js.Object = s"#${selectizeInputId} > .selectize-control> .selectize-input > div"

    $(selector).each((y: Element) => {
      val dataVal = $(y).attr("data-value").toString
      try {
        val cnxn = upickle.default.read[Connection](dataVal)
        selectedConnections :+= cnxn
      } catch {
        case e: Exception =>
          selectedLabels :+= dataVal
      }
    })
    (selectedConnections, selectedLabels)
  }

  def filterLabelStrings(value: Seq[String], character: String): Seq[String] = {
    value
      .filter(e => e.charAt(0) == '#' && e.count(_.toString == character) == 1)
      .map(_.replace(character, "")).distinct

  }

  case class Props(parentIdentifier: String)

  case class State(maxItems: Int = 7,
                   maxCharLimit: Int = 16,
                   allowNewItemsCreation: Boolean = false)

  case class Backend(t: BackendScope[Props, State]) {
    def initializeTagsInput(): Unit = {
      val state = t.state.runNow()
      val parentIdentifier = t.props.runNow().parentIdentifier
      SynereoSelectizeFacade.initilizeSelectize(s"${parentIdentifier}-selectize", state.maxItems, state.maxCharLimit, state.allowNewItemsCreation)
    }

    def mounted(props: Props): Callback = Callback {
      initializeTagsInput()
    }


    def render(props: Props, state: State) = {
      <.select(^.className := "select-state", ^.id := s"${props.parentIdentifier}-selectize",
        ^.className := "demo-default", ^.placeholder := "Search e.g. @Synereo or #fun")(
        <.option(^.value := "")("Select"),
        for (connection <- SYNEREOCircuit.zoom(_.connections).value.connectionsResponse) yield <.option(^.value := upickle.default.write(connection.connection),
          ^.key := connection.connection.target)(s"@${connection.name}"),
        for (label <- SYNEREOCircuit.zoom(_.searches).value.searchesModel) yield
          <.option(^.value := label.text, ^.key := label.uid)(s"#${label.text}")
      )

    }
  }

  val component = ReactComponentB[Props]("SearchesConnectionList")
    .initialState(State())
    .renderBackend[Backend]
    .componentDidMount(scope => scope.backend.mounted(scope.props))
    .build

  def apply(props: Props) = component(props)
} 
Example 17
Source File: UserPersona.scala    From ProductWebUI   with Apache License 2.0
package synereo.client.components

import diode.react.ModelProxy
import japgolly.scalajs.react.{ReactComponentB, _}
import japgolly.scalajs.react.vdom.prefix_<^._
import shared.models.UserModel
import synereo.client.css.{NewMessageCSS, SynereoCommanStylesCSS}
import scala.language.existentials
import scalacss.ScalaCssReact._


//scalastyle:off
object UserPersona {

  def getPersona(): String = {
    ""
  }

  case class Props(proxy: ModelProxy[UserModel])

  case class Backend(t: BackendScope[Props, _]) {


    def mounted(props: Props): Callback = Callback {
      //      println("UserPersona is : " + props.proxy.value)
    }

    def render(props: Props) = {
      val model = props.proxy.value
      <.div(^.className := "row", NewMessageCSS.Style.PersonaContainerDiv)(
        <.div(^.className := "col-md-2 col-sm-2 col-xs-2", SynereoCommanStylesCSS.Style.paddingLeftZero)(
          <.img(^.alt := "userImage", ^.src := model.imgSrc, ^.className := "img-responsive", NewMessageCSS.Style.userImage)
        ),
        <.div(^.className := "col-md-10", SynereoCommanStylesCSS.Style.paddingLeftZero, SynereoCommanStylesCSS.Style.paddingRightZero)(
          <.div(
            <.button(^.className := "btn", ^.`type` := "button", NewMessageCSS.Style.changePersonaBtn)("Change posting persona", <.span(^.className := "caret", ^.color.blue)),
            <.div(^.className := "pull-right hidden-xs")(MIcon.apply("more_vert", "24"))
          )
        ),
        <.div(NewMessageCSS.Style.userNameOnDilogue)(
          <.div(model.name, <.span(Icon.chevronRight), "public", <.span(Icon.share))
        )
      )
    }
  }

  val component = ReactComponentB[Props]("UserPersona")
    .renderBackend[Backend]
    .componentDidMount(scope => scope.backend.mounted(scope.props))
    .build

  def apply(props: Props) = component(props)

} 
Example 18
Source File: LabelsSelectize.scala    From ProductWebUI   with Apache License 2.0
package client.components

import client.utils.LabelsUtils
import diode.react.ModelProxy
import japgolly.scalajs.react._
import japgolly.scalajs.react.vdom.prefix_<^._
import org.denigma.selectize._
import org.querki.jquery._
import org.scalajs.dom._
import client.rootmodel.SearchesRootModel
import shared.models.Label
import client.sessionitems.SessionItems

import scala.collection.mutable.ListBuffer
import scala.language.existentials
import scala.scalajs.js

object LabelsSelectize {

  def getLabelsTxtFromSelectize(selectizeInputId: String): Seq[String] = {
    var selectedLabels = Seq[String]()
    val selector: js.Object = s"#${selectizeInputId} > .selectize-control> .selectize-input > div"
    if ($(selector).length > 0) {
      $(selector).each((y: Element) => selectedLabels :+= $(y).attr("data-value").toString)
    } else {
      selectedLabels = Nil
    }

    selectedLabels
  }

  def getLabelsFromSelectizeInput(selectizeInputId: String): Seq[Label] = {
    var selectedLabels = Seq[Label]()
    val selector: js.Object = s"#${selectizeInputId} > .selectize-control> .selectize-input > div"

    $(selector).each((y: Element) => selectedLabels :+= upickle.default.read[Label]($(y).attr("data-value").toString))
    selectedLabels
  }

  var getSelectedValue = new ListBuffer[String]()

  
  case class Props(proxy: ModelProxy[SearchesRootModel], parentIdentifier: String)

  case class Backend(t: BackendScope[Props, _]) {
    def initializeTagsInput(parentIdentifier: String): Unit = {
      val selectState: js.Object = s"#$parentIdentifier > .selectize-control"
      //      println(s"element lenth: ${$(selectState).length}")
      if ($(selectState).length < 1) {
        val selectizeInput: js.Object = "#labelsSelectize"
        //        $(selectizeInput).selectize(SelectizeConfig.maxOptions(2)).destroy()
        //        println(s"test : ${$(selectizeInput)}")
        $(selectizeInput).selectize(SelectizeConfig
          .create(true)
          .maxItems(3)
          .plugins("remove_button"))
      }

    }

    def getSelectedValues = Callback {
      val selectState: js.Object = "#selectize"
      val getSelectedValue = $(selectState).find("option").text()
      //scalastyle:off
      //      println(getSelectedValue)
    }

    def mounted(props: Props): Callback = Callback {
      //      println("searches model is = " + props.proxy().searchesModel)
      initializeTagsInput(props.parentIdentifier)
    }

    def render(props: Props) = {
      val parentDiv: js.Object = s"#${props.parentIdentifier}"
      //      println(s"parent div length ${$(parentDiv).length}")
      if ($(parentDiv).length == 0) {
        <.select(^.className := "select-state", ^.id := "labelsSelectize", ^.className := "demo-default", ^.placeholder := "select #label(s)", ^.onChange --> getSelectedValues)(
          <.option(^.value := "")("Select"),
          //          props.proxy().render(searchesRootModel => searchesRootModel.se)
          for (label <- props.proxy().searchesModel
            .filter(e => e.parentUid == "self")
            .filterNot(e => LabelsUtils.getSystemLabels().contains(e.text))) yield {
            <.option(^.value := upickle.default.write(label), ^.key := label.uid)(label.text)
          }
        )
      } else {
        <.div()
      }
    }
  }

  val component = ReactComponentB[Props]("SearchesConnectionList")
    .renderBackend[Backend]
    .componentDidMount(scope => scope.backend.mounted(scope.props))
    .build

  def apply(props: Props) = component(props)
} 
Example 19
Source File: Dashboard.scala    From scalajs-spa-tutorial   with Apache License 2.0
package spatutorial.client.modules

import diode.data.Pot
import diode.react._
import japgolly.scalajs.react._
import japgolly.scalajs.react.extra.router.RouterCtl
import japgolly.scalajs.react.vdom.html_<^._
import spatutorial.client.SPAMain.{Loc, TodoLoc}
import spatutorial.client.components._

import scala.util.Random
import scala.language.existentials

object Dashboard {

  case class Props(router: RouterCtl[Loc], proxy: ModelProxy[Pot[String]])

  case class State(motdWrapper: ReactConnectProxy[Pot[String]])

  // create dummy data for the chart
  val cp = Chart.ChartProps(
    "Test chart",
    Chart.BarChart,
    ChartData(
      Random.alphanumeric.map(_.toUpper.toString).distinct.take(10),
      Seq(ChartDataset(Iterator.continually(Random.nextDouble() * 10).take(10).toSeq, "Data1"))
    )
  )

  // create the React component for Dashboard
  private val component = ScalaComponent.builder[Props]("Dashboard")
    // create and store the connect proxy in state for later use
    .initialStateFromProps(props => State(props.proxy.connect(m => m)))
    .renderPS { (_, props, state) =>
      <.div(
        // header, MessageOfTheDay and chart components
        <.h2("Dashboard"),
        state.motdWrapper(Motd(_)),
        Chart(cp),
        // create a link to the To Do view
        <.div(props.router.link(TodoLoc)("Check your todos!"))
      )
    }
    .build

  def apply(router: RouterCtl[Loc], proxy: ModelProxy[Pot[String]]) = component(Props(router, proxy))
} 
Example 20
Source File: ShuffleMapTask.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.scheduler

import java.nio.ByteBuffer

import scala.language.existentials

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleWriter


  def this(partitionId: Int) {
    this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
  }

  @transient private val preferredLocs: Seq[TaskLocation] = {
    if (locs == null) Nil else locs.toSet.toSeq
  }

  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val deserializeStartTime = System.currentTimeMillis()
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

  override def preferredLocations: Seq[TaskLocation] = preferredLocs

  override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
} 
Example 21
Source File: AnyFormatSpec.scala    From scalapb-json4s   with Apache License 2.0
package scalapb.json4s

import com.google.protobuf.any.{Any => PBAny}
import jsontest.anytests.{AnyTest, ManyAnyTest}
import org.json4s.jackson.JsonMethods._

import scala.language.existentials
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

class AnyFormatSpec extends AnyFlatSpec with Matchers with JavaAssertions {
  val RawExample = AnyTest("test")

  val RawJson = parse(s"""{"field":"test"}""")

  val AnyExample = PBAny.pack(RawExample)

  val AnyJson = parse(
    s"""{"@type":"type.googleapis.com/jsontest.AnyTest","field":"test"}"""
  )

  val CustomPrefixAny = PBAny.pack(RawExample, "example.com/")

  val CustomPrefixJson = parse(
    s"""{"@type":"example.com/jsontest.AnyTest","field":"test"}"""
  )

  val ManyExample = ManyAnyTest(
    Seq(
      PBAny.pack(AnyTest("1")),
      PBAny.pack(AnyTest("2"))
    )
  )

  val ManyPackedJson = parse(
    """
      |{
      |  "@type": "type.googleapis.com/jsontest.ManyAnyTest",
      |  "fields": [
      |    {"@type": "type.googleapis.com/jsontest.AnyTest", "field": "1"},
      |    {"@type": "type.googleapis.com/jsontest.AnyTest", "field": "2"}
      |  ]
      |}
    """.stripMargin
  )

  override def registeredCompanions = Seq(AnyTest, ManyAnyTest)

  // For clarity
  def UnregisteredPrinter = JsonFormat.printer

  def UnregisteredParser = JsonFormat.parser

  "Any" should "fail to serialize if its respective companion is not registered" in {
    an[IllegalStateException] must be thrownBy UnregisteredPrinter.toJson(
      AnyExample
    )
  }

  "Any" should "fail to deserialize if its respective companion is not registered" in {
    a[JsonFormatException] must be thrownBy UnregisteredParser.fromJson[PBAny](
      AnyJson
    )
  }

  "Any" should "serialize correctly if its respective companion is registered" in {
    ScalaJsonPrinter.toJson(AnyExample) must be(AnyJson)
  }

  "Any" should "fail to serialize with a custom URL prefix if specified" in {
    an[IllegalStateException] must be thrownBy ScalaJsonPrinter.toJson(
      CustomPrefixAny
    )
  }

  "Any" should "fail to deserialize for a non-Google-prefixed type URL" in {
    a[JsonFormatException] must be thrownBy ScalaJsonParser.fromJson[PBAny](
      CustomPrefixJson
    )
  }

  "Any" should "deserialize correctly if its respective companion is registered" in {
    ScalaJsonParser.fromJson[PBAny](AnyJson) must be(AnyExample)
  }

  "Any" should "be serialized the same as in Java (and parsed back to original)" in {
    assertJsonIsSameAsJava(AnyExample)
  }

  "Any" should "resolve printers recursively" in {
    val packed = PBAny.pack(ManyExample)
    ScalaJsonPrinter.toJson(packed) must be(ManyPackedJson)
  }

  "Any" should "resolve parsers recursively" in {
    ScalaJsonParser.fromJson[PBAny](ManyPackedJson).unpack[ManyAnyTest] must be(
      ManyExample
    )
  }
} 
Example 22
Source File: AnyFormat.scala    From scalapb-json4s   with Apache License 2.0
package scalapb.json4s

import com.google.protobuf.any.{Any => PBAny}
import org.json4s.JsonAST.{JNothing, JObject, JString, JValue}

import scala.language.existentials

object AnyFormat {
  val anyWriter: (Printer, PBAny) => JValue = {
    case (printer, any) =>
      // Find the companion so it can be used to JSON-serialize the message. Perhaps this can be circumvented by
      // including the original GeneratedMessage with the Any (at least in memory).
      val cmp = printer.typeRegistry
        .findType(any.typeUrl)
        .getOrElse(
          throw new IllegalStateException(
            s"Unknown type ${any.typeUrl} in Any.  Add a TypeRegistry that supports this type to the Printer."
          )
        )

      // Unpack the message...
      val message = any.unpack(cmp)

      // ... and add the @type marker to the resulting JSON
      printer.toJson(message) match {
        case JObject(fields) =>
          JObject(("@type" -> JString(any.typeUrl)) +: fields)
        case value =>
          // Safety net, this shouldn't happen
          throw new IllegalStateException(
            s"Message of type ${any.typeUrl} emitted non-object JSON: $value"
          )
      }
  }

  val anyParser: (Parser, JValue) => PBAny = {
    case (parser, obj @ JObject(fields)) =>
      obj \ "@type" match {
        case JString(typeUrl) =>
          val cmp = parser.typeRegistry
            .findType(typeUrl)
            .getOrElse(
              throw new JsonFormatException(
                s"Unknown type ${typeUrl} in Any.  Add a TypeRegistry that supports this type to the Parser."
              )
            )
          val message = parser.fromJson(obj, true)(cmp)
          PBAny(typeUrl = typeUrl, value = message.toByteString)

        case JNothing =>
          throw new JsonFormatException(s"Missing type url when parsing $obj")

        case unknown =>
          throw new JsonFormatException(
            s"Expected string @type field, got $unknown"
          )
      }

    case (_, unknown) =>
      throw new JsonFormatException(s"Expected an object, got $unknown")
  }
} 
Example 23
Source File: ModelSerializabilityTestBase.scala    From aloha   with MIT License
package com.eharmony.aloha

import scala.language.existentials
import com.eharmony.aloha
import com.eharmony.aloha.models.{Model, SubmodelBase}
import org.junit.Assert._
import org.junit.Test
import org.reflections.Reflections

import scala.collection.JavaConversions.asScalaSet
import scala.util.Try
import java.lang.reflect.{Method, Modifier}

import com.eharmony.aloha.util.Logging


abstract class ModelSerializabilityTestBase(pkgs: Seq[String], outFilters: Seq[String])
extends Logging {

  def this() = this(pkgs = Seq(aloha.pkgName), Seq.empty)

  @Test def testSerialization(): Unit = {
    val ref = new Reflections(pkgs:_*)
    val submodels = ref.getSubTypesOf(classOf[SubmodelBase[_, _, _, _]]).toSeq
    val models = ref.getSubTypesOf(classOf[Model[_, _]]).toSeq

    val modelClasses =
      (models ++ submodels).
        filterNot { _.isInterface }.
        filterNot { c =>
          val name = c.getName
          outFilters.exists(name.matches)
        }

    if (modelClasses.isEmpty) {
      fail(s"No models found to test for Serializability in packages: ${pkgs.mkString(",")}")
    }
    else {
      debug {
        modelClasses
          .map(_.getCanonicalName)
          .mkString("Models tested for Serializability:\n\t", "\n\t", "")
      }
    }

    modelClasses.foreach { c =>
      val m = for {
        testClass  <- getTestClass(c.getCanonicalName)
        testMethod <- getTestMethod(testClass)
        method     <- ensureTestMethodIsTest(testMethod)
      } yield method

      m.left foreach fail
    }
  }

  private[this] implicit class RightMonad[L, R](e: Either[L, R]) {
    def flatMap[R1](f: R => Either[L, R1]) = e.right.flatMap(f)
    def map[R1](f: R => R1) = e.right.map(f)
  }

  private[this] def getTestClass(modelClassName: String) = {
    val testName = modelClassName + "Test"
    Try {
      Class.forName(testName)
    } map {
      Right(_)
    } getOrElse Left("No test class exists for " + modelClassName)
  }

  private[this] def getTestMethod(testClass: Class[_]) = {
    val testMethodName = "testSerialization"
    lazy val msg = s"$testMethodName doesn't exist in ${testClass.getCanonicalName}."
    Try {
      Option(testClass.getMethod(testMethodName))
    } map {
      case Some(m) => Right(m)
      case None => Left(msg)
    } getOrElse Left(msg)
  }

  private[this] def ensureTestMethodIsTest(method: Method) = {
    if (!Modifier.isPublic(method.getModifiers))
      Left(s"testSerialization in ${method.getDeclaringClass.getCanonicalName} is not public")
    else if (!method.getDeclaredAnnotations.exists(_.annotationType() == classOf[Test]))
      Left(s"testSerialization in ${method.getDeclaringClass.getCanonicalName} does not have a @org.junit.Test annotation.")
    else if (method.getReturnType != classOf[Void] && method.getReturnType != classOf[Unit])
      Left(s"testSerialization in ${method.getDeclaringClass.getCanonicalName} is not a void function. It returns: ${method.getReturnType}")
    else Right(method)
  }
} 
Example 24
Source File: fields.scala    From aloha   with MIT License
package com.eharmony.aloha.semantics.compiled.plugin.schemabased.schema

import com.eharmony.aloha.reflect.RefInfo

import scala.language.existentials

// RECORD, ENUM, ARRAY, MAP, UNION, FIXED, STRING, BYTES, INT, LONG, FLOAT, DOUBLE, BOOLEAN, NULL;

sealed trait FieldDesc {
  def name: String
  def index: Int
  def nullable: Boolean
}

// TODO: Add additional types as necessary.

case class RecordField(name: String, index: Int, schema: Schema, refInfo: RefInfo[_], nullable: Boolean) extends FieldDesc
case class EnumField(name: String, index: Int, nullable: Boolean) extends FieldDesc
case class ListField(name: String, index: Int, elementType: FieldDesc, nullable: Boolean) extends FieldDesc
case class StringField(name: String, index: Int, nullable: Boolean) extends FieldDesc
case class IntField(name: String, index: Int, nullable: Boolean) extends FieldDesc
case class LongField(name: String, index: Int, nullable: Boolean) extends FieldDesc
case class FloatField(name: String, index: Int, nullable: Boolean) extends FieldDesc
case class DoubleField(name: String, index: Int, nullable: Boolean) extends FieldDesc
case class BooleanField(name: String, index: Int, nullable: Boolean) extends FieldDesc 
Example 25
Source File: HadoopUtils.scala    From spark-images   with Apache License 2.0
package org.apache.spark.image

import java.nio.file.Paths

import org.apache.commons.io.FilenameUtils

import scala.sys.process._
import org.apache.hadoop.conf.{Configuration, Configured}
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.sql.SparkSession
import scala.language.existentials
import scala.util.Random

object RecursiveFlag {

  
  def setPathFilter(value: Option[Class[_]], sampleRatio: Option[Double] = None, spark: SparkSession)
  : Option[Class[_]] = {
    val flagName = FileInputFormat.PATHFILTER_CLASS
    val hadoopConf = spark.sparkContext.hadoopConfiguration
    val old = Option(hadoopConf.getClass(flagName, null))
    if (sampleRatio.isDefined) {
      hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio.get)
    } else {
      hadoopConf.unset(SamplePathFilter.ratioParam)
      None
    }

    value match {
      case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter])
      case None => hadoopConf.unset(flagName)
    }
    old
  }
} 
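A hedged usage sketch: SamplePathFilter is defined elsewhere in the same project, and spark is assumed to be an active SparkSession. It installs a sampling path filter, then restores whatever filter class was configured before:

val previous = RecursiveFlag.setPathFilter(
  value = Some(classOf[SamplePathFilter]), sampleRatio = Some(0.1), spark = spark)
try {
  // ... read files here; roughly 10% of paths pass the filter ...
} finally {
  RecursiveFlag.setPathFilter(previous, None, spark)
}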
Example 26
Source File: JobUtils.scala    From fusion-data   with Apache License 2.0 5 votes vote down vote up
package mass.job.util

import java.io.File
import java.nio.charset.Charset
import java.nio.file.{ Files, Path, StandardCopyOption }
import java.util.zip.ZipFile

import com.typesafe.scalalogging.StrictLogging
import helloscala.common.Configuration
import helloscala.common.util.{ DigestUtils, Utils }
import mass.common.util.FileUtils
import mass.core.job.JobConstants
import mass.job.JobSettings
import mass.message.job._
import mass.model.job.{ JobItem, JobTrigger }

import scala.concurrent.{ ExecutionContext, Future }


object JobUtils extends StrictLogging {
  case class JobZipInternal private (configs: Vector[JobCreateReq], entries: Vector[Path])

  def uploadJob(jobSettings: JobSettings, req: JobUploadJobReq)(implicit ec: ExecutionContext): Future[JobZip] =
    Future {
      val sha256 = DigestUtils.sha256HexFromPath(req.file)
      val dest = jobSettings.jobSavedDir.resolve(sha256.take(2)).resolve(sha256)

      val jobZipInternal = parseJobZip(req.file, req.charset, dest.resolve(JobConstants.DIST)) match {
        case Right(v) => v
        case Left(e)  => throw e
      }

      val zipPath = dest.resolve(req.fileName)
      Files.move(req.file, zipPath, StandardCopyOption.REPLACE_EXISTING)
      JobZip(zipPath, jobZipInternal.configs, jobZipInternal.entries)
    }

  @inline def parseJobZip(file: Path, charset: Charset, dest: Path): Either[Throwable, JobZipInternal] =
    parseJobZip(file.toFile, charset, dest)

  def parseJobZip(file: File, charset: Charset, dest: Path): Either[Throwable, JobZipInternal] = Utils.either {
    import scala.jdk.CollectionConverters._
    import scala.language.existentials

    val zip = new ZipFile(file, charset)
    try {
      val (confEntries, fileEntries) = zip
        .entries()
        .asScala
        .filterNot(entry => entry.isDirectory)
        .span(entry => entry.getName.endsWith(JobConstants.ENDS_SUFFIX) && !entry.isDirectory)
      val configs =
        confEntries.map(confEntry =>
          parseJobConf(FileUtils.getString(zip.getInputStream(confEntry), charset, "\n")) match {
            case Right(config) => config
            case Left(e)       => throw e
          })

      val buf = Array.ofDim[Byte](1024)
      val entryPaths = fileEntries.map { entry =>
        val entryName = entry.getName
        val savePath = dest.resolve(entryName)
        if (!Files.isDirectory(savePath.getParent)) {
          Files.createDirectories(savePath.getParent)
        }
        FileUtils.write(zip.getInputStream(entry), Files.newOutputStream(savePath), buf) // write the zip entry to disk
        savePath
      }

      JobZipInternal(configs.toVector, entryPaths.toVector)
    } finally {
      if (zip ne null) zip.close()
    }
  }

  def parseJobConf(content: String): Either[Throwable, JobCreateReq] = Utils.either {
    val conf = Configuration.parseString(content)
    val jobItem = JobItem(conf.getConfiguration("item"))
    val jobTrigger = JobTrigger(conf.getConfiguration("trigger"))
    JobCreateReq(conf.get[Option[String]]("key"), jobItem, jobTrigger)
  }
}

case class JobZip(zipPath: Path, configs: Vector[JobCreateReq], entries: Vector[Path]) 
Example 27
Source File: ScheduledTaskManager.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.utils

import scala.language.existentials
import java.util.concurrent._
import java.util.UUID
import com.google.common.util.concurrent.ThreadFactoryBuilder
import ScheduledTaskManager._
import scala.util.Try


  def stop() = {
    _taskMap.clear()
    _scheduler.shutdown()
  }
}

object ScheduledTaskManager {
  val DefaultMaxThreads = 4
  val DefaultExecutionDelay = 10 // 10 milliseconds
  val DefaultTimeInterval = 100 // 100 milliseconds
} 
Example 28
Source File: TipTestSuite.scala    From inox   with Apache License 2.0 5 votes vote down vote up
package inox
package tip

import solvers._

import scala.language.existentials

class TipTestSuite extends TestSuite with ResourceUtils {

  override def configurations = Seq(
    Seq(optSelectedSolvers(Set("nativez3")), optCheckModels(true)),
    Seq(optSelectedSolvers(Set("smt-z3")),   optCheckModels(true)),
    Seq(optSelectedSolvers(Set("smt-cvc4")), optCheckModels(true)),
    Seq(optSelectedSolvers(Set("smt-z3")),   optCheckModels(true), optAssumeChecked(true))
  )

  override protected def optionsString(options: Options): String = {
    "solver=" + options.findOptionOrDefault(optSelectedSolvers).head +
    (if (options.findOptionOrDefault(optAssumeChecked)) " assumechecked" else "")
  }

  private def ignoreSAT(ctx: Context, file: java.io.File): FilterStatus = 
    ctx.options.findOptionOrDefault(optSelectedSolvers).headOption match {
      case Some(solver) => (solver, file.getName) match {
        // this test contains a list of booleans, so CVC4 will crash on it
        // See http://church.cims.nyu.edu/bugzilla3/show_bug.cgi?id=500
        case ("smt-cvc4", "List-fold.tip") => Skip
        // Z3 and CVC4 binaries are exceedingly slow on these benchmarks
        case ("smt-z3" | "smt-cvc4", "BinarySearchTreeQuant.scala-2.tip") => Ignore
        case ("smt-z3" | "smt-cvc4", "ForallAssoc.scala-0.tip") => Ignore
        // this test only holds when assumeChecked=false
        case (_, "LambdaEquality2.scala-1.tip") if ctx.options.findOptionOrDefault(optAssumeChecked) => Skip
        case _ => Test
      }

      case _ => Test
    }

  private def ignoreUNSAT(ctx: Context, file: java.io.File): FilterStatus = 
    ctx.options.findOptionOrDefault(optSelectedSolvers).headOption match {
      case Some(solver) => (solver, file.getName) match {
        // Z3 binary will predictably segfault on certain permutations of this problem
        case ("smt-z3", "MergeSort2.scala-1.tip") => Ignore
        // use non-linear operators that aren't supported in CVC4
        case ("smt-cvc4", "Instantiation.scala-0.tip") => Skip
        case ("smt-cvc4", "LetsInForall.tip") => Skip
        case ("smt-cvc4", "Weird.scala-0.tip") => Skip
        // this test only holds when assumeChecked=true
        case (_, "QuickSortFilter.scala-1.tip") if !ctx.options.findOptionOrDefault(optAssumeChecked) => Skip
        case _ => Test
      }
      case _ => Test
    }

  private def ignoreUNKNOWN(ctx: Context, file: java.io.File): FilterStatus = 
    ctx.options.findOptionOrDefault(optSelectedSolvers).headOption match {
      case Some(solver) => (solver, file.getName) match {
        // non-linear operations are too slow on smt-z3
        case ("smt-z3", "Soundness2.scala-0.tip") => Ignore
        // use non-linear operators that aren't supported in CVC4
        case ("smt-cvc4", "Soundness.scala-0.tip") => Skip
        case ("smt-cvc4", "Soundness2.scala-0.tip") => Skip
        case _ => Test
      }
      case _ => Test
    }

  for (file <- resourceFiles("regression/tip/SAT", _.endsWith(".tip"))) {
    test(s"SAT - ${file.getName}", ignoreSAT(_, file)) { implicit ctx =>
      for ((program, expr) <- Parser(file).parseScript) {
        assert(SimpleSolverAPI(program.getSolver).solveSAT(expr).isSAT)
      }
    }
  }

  for (file <- resourceFiles("regression/tip/UNSAT", _.endsWith(".tip"))) {
    test(s"UNSAT - ${file.getName}", ignoreUNSAT(_, file)) { implicit ctx =>
      for ((program, expr) <- Parser(file).parseScript) {
        assert(SimpleSolverAPI(program.getSolver).solveSAT(expr).isUNSAT)
      }
    }
  }

  for (file <- resourceFiles("regression/tip/UNKNOWN", _.endsWith(".tip"))) {
    test(s"UNKNOWN - ${file.getName}", ignoreUNKNOWN(_, file)) { ctx0 =>
      implicit val ctx = ctx0.copy(options = ctx0.options + optCheckModels(false))
      for ((program, expr) <- Parser(file).parseScript) {
        val api = SimpleSolverAPI(program.getSolver)
        val res = api.solveSAT(expr)
        assert(!res.isSAT && !res.isUNSAT)
        assert(ctx.reporter.errorCount > 0)
      }
    }
  }
} 
Example 29
Source File: string_formats_yaml.base.scala    From play-swagger   with MIT License 5 votes vote down vote up
package string_formats.yaml

import scala.language.existentials

import play.api.mvc.{Action, Controller, Results}
import play.api.http._
import Results.Status

import de.zalando.play.controllers.{PlayBodyParsing, ParsingError, ResultWrapper}
import PlayBodyParsing._
import scala.util._
import de.zalando.play.controllers.Base64String
import Base64String._
import de.zalando.play.controllers.BinaryString
import BinaryString._
import org.joda.time.DateTime
import java.util.UUID
import org.joda.time.LocalDate

import de.zalando.play.controllers.PlayPathBindables





trait String_formatsYamlBase extends Controller with PlayBodyParsing {
    sealed trait GetType[T] extends ResultWrapper[T]
    
    case object Get200 extends EmptyReturn(200)
    

    private type getActionRequestType       = (GetDate_time, GetDate, GetBase64, GetUuid, BinaryString)
    private type getActionType[T]            = getActionRequestType => GetType[T] forSome { type T }

        private def getParser(acceptedTypes: Seq[String], maxLength: Int = parse.DefaultMaxTextLength) = {
            def bodyMimeType: Option[MediaType] => String = mediaType => {
                val requestType = mediaType.toSeq.map {
                    case m: MediaRange => m
                    case MediaType(a,b,c) => new MediaRange(a,b,c,None,Nil)
                }
                negotiateContent(requestType, acceptedTypes).orElse(acceptedTypes.headOption).getOrElse("application/json")
            }
            
            import de.zalando.play.controllers.WrappedBodyParsers
            
            val customParsers = WrappedBodyParsers.anyParser[BinaryString]
            anyParser[BinaryString](bodyMimeType, customParsers, "Invalid BinaryString", maxLength)
        }

    val getActionConstructor  = Action
    def getAction[T] = (f: getActionType[T]) => (date_time: GetDate_time, date: GetDate, base64: GetBase64, uuid: GetUuid) => getActionConstructor(getParser(Seq[String]())) { request =>
        val providedTypes = Seq[String]("application/json", "application/yaml")

        negotiateContent(request.acceptedTypes, providedTypes).map { getResponseMimeType =>

            val petId = request.body
            
            

                val result =
                        new GetValidator(date_time, date, base64, uuid, petId).errors match {
                            case e if e.isEmpty => processValidgetRequest(f)((date_time, date, base64, uuid, petId))(getResponseMimeType)
                            case l =>
                                implicit val marshaller: Writeable[Seq[ParsingError]] = parsingErrors2Writable(getResponseMimeType)
                                BadRequest(l)
                        }
                result
            
        }.getOrElse(Status(406)("The server doesn't support any of the requested mime types"))
    }

    private def processValidgetRequest[T](f: getActionType[T])(request: getActionRequestType)(mimeType: String) = {
      f(request).toResult(mimeType).getOrElse {
        Results.NotAcceptable
      }
    }
    abstract class EmptyReturn(override val statusCode: Int = 204) extends ResultWrapper[Results.EmptyContent]  with GetType[Results.EmptyContent] { val result = Results.EmptyContent(); val writer = (x: String) => Some(new DefaultWriteables{}.writeableOf_EmptyContent); override def toResult(mimeType: String): Option[play.api.mvc.Result] = Some(Results.NoContent) }
    case object NotImplementedYet extends ResultWrapper[Results.EmptyContent]  with GetType[Results.EmptyContent] { val statusCode = 501; val result = Results.EmptyContent(); val writer = (x: String) => Some(new DefaultWriteables{}.writeableOf_EmptyContent); override def toResult(mimeType: String): Option[play.api.mvc.Result] = Some(Results.NotImplemented) }
} 
Example 30
Source File: FPTreeSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.fpm

import scala.language.existentials

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("add transaction") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    assert(tree.root.children.size == 2)
    assert(tree.root.children.contains("a"))
    assert(tree.root.children("a").item.equals("a"))
    assert(tree.root.children("a").count == 2)
    assert(tree.root.children.contains("b"))
    assert(tree.root.children("b").item.equals("b"))
    assert(tree.root.children("b").count == 1)
    var child = tree.root.children("a")
    assert(child.children.size == 1)
    assert(child.children.contains("b"))
    assert(child.children("b").item.equals("b"))
    assert(child.children("b").count == 2)
    child = child.children("b")
    assert(child.children.size == 2)
    assert(child.children.contains("c"))
    assert(child.children.contains("y"))
    assert(child.children("c").item.equals("c"))
    assert(child.children("y").item.equals("y"))
    assert(child.children("c").count == 1)
    assert(child.children("y").count == 1)
  }

  test("merge tree") {
    val tree1 = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    val tree2 = new FPTree[String]
      .add(Seq("a", "b"))
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "c", "d"))
      .add(Seq("a", "x"))
      .add(Seq("a", "x", "y"))
      .add(Seq("c", "n"))
      .add(Seq("c", "m"))

    val tree3 = tree1.merge(tree2)

    assert(tree3.root.children.size == 3)
    assert(tree3.root.children("a").count == 7)
    assert(tree3.root.children("b").count == 1)
    assert(tree3.root.children("c").count == 2)
    val child1 = tree3.root.children("a")
    assert(child1.children.size == 2)
    assert(child1.children("b").count == 5)
    assert(child1.children("x").count == 2)
    val child2 = child1.children("b")
    assert(child2.children.size == 2)
    assert(child2.children("y").count == 1)
    assert(child2.children("c").count == 3)
    val child3 = child2.children("c")
    assert(child3.children.size == 1)
    assert(child3.children("d").count == 1)
    val child4 = child1.children("x")
    assert(child4.children.size == 1)
    assert(child4.children("y").count == 1)
    val child5 = tree3.root.children("c")
    assert(child5.children.size == 2)
    assert(child5.children("n").count == 1)
    assert(child5.children("m").count == 1)
  }

  test("extract freq itemsets") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("a", "b"))
      .add(Seq("a"))
      .add(Seq("b"))
      .add(Seq("b", "n"))

    val freqItemsets = tree.extract(3L).map { case (items, count) =>
      (items.toSet, count)
    }.toSet
    val expected = Set(
      (Set("a"), 4L),
      (Set("b"), 5L),
      (Set("a", "b"), 3L))
    assert(freqItemsets === expected)
  }
} 
Example 31
Source File: DataRow.scala    From flink-elasticsearch-source-connector   with Apache License 2.0 5 votes vote down vote up
package com.mnubo.flink.streaming.connectors

import org.apache.commons.lang3.ClassUtils
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractor

import scala.language.existentials

case class Value(v: Any, name: String, givenTypeInfo: Option[TypeInformation[_]] = None) {
  require(v != null || givenTypeInfo.isDefined, "You must pass a TypeInformation for null values")

  val typeInfo = givenTypeInfo match {
    case Some(ti) => ti
    case None => TypeExtractor.getForObject(v)
  }

  require(isAssignable(v, typeInfo.getTypeClass), s"data element '$v' is not compatible with class ${typeInfo.getTypeClass.getName}")

  private def isAssignable(value: Any, cl: Class[_]) = {
    if (value == null && classOf[AnyRef].isAssignableFrom(cl))
      true
    else
      ClassUtils.isAssignable(value.getClass, cl)
  }
}

object Value {
  def apply(v: Any, name: String, givenTypeInfo: TypeInformation[_]) = {
    new Value(v, name, Some(givenTypeInfo))
  }
}



class DataRow(private [connectors] val data: Array[Any], private [connectors] val info: DataRowTypeInfo) extends Product with Serializable {
  require(data != null, "data must not be null")
  require(info != null, "info must not be null")
  require(data.length == info.getArity, "data must be of the correct arity")

  def apply[T](i: Int): T =
    data(i).asInstanceOf[T]

  def apply[T](fieldExpression: String): T =
    apply(info.getFieldIndex(fieldExpression))

  override def productElement(n: Int): Any =
    apply[AnyRef](n)

  override def productArity =
    info.getArity

  override def canEqual(that: Any) =
    that.isInstanceOf[DataRow]

  override def equals(that: Any) =
    canEqual(that) && data.sameElements(that.asInstanceOf[DataRow].data) && info.getFieldNames.sameElements(that.asInstanceOf[DataRow].info.getFieldNames)

  override def hashCode = {
    var result = 1

    for (element <- data)
      result = 31 * result + (if (element == null) 0 else element.hashCode)

    result
  }

  override def toString =
    info.getFieldNames
      .zip(data.map(v => if (v == null) "null" else v.toString))
      .map{case (name, value) => s"$name=$value"}
      .mkString("DataRow(", ", ", ")")
}

object DataRow {
  
  def apply(data: Value*): DataRow = {
    require(data != null, "data cannot be null")
    require(!data.contains(null), "data value cannot be null")

    new DataRow(
      data.map(_.v).toArray,
      new DataRowTypeInfo(
        data.map(_.name),
        data.map(_.typeInfo)
      )
    )
  }
} 
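A minimal sketch with hypothetical values: each Value pairs an element with a field name, and a null element must carry an explicit TypeInformation, which is exactly where the existential TypeInformation[_] shows up:

import org.apache.flink.api.common.typeinfo.BasicTypeInfo

val row = DataRow(
  Value("user-42", "id"),
  Value(3.14, "score"),
  Value(null, "comment", BasicTypeInfo.STRING_TYPE_INFO))

val id: String = row[String]("id")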
Example 32
Source File: RecordTransformer.scala    From flink-elasticsearch-source-connector   with Apache License 2.0 5 votes vote down vote up
package com.mnubo.flink.streaming.connectors

import org.apache.flink.api.common.operators.Keys.ExpressionKeys._
import org.apache.flink.api.common.typeinfo.TypeInformation

import scala.annotation.tailrec
import scala.language.existentials
import scala.reflect.ClassTag

sealed trait FieldSpecification extends Serializable

case class ExistingField(name: String) extends FieldSpecification

case class NewField(name: String, typeInfo: TypeInformation[_]) extends FieldSpecification

trait RecordTransformer extends Serializable {
  val classTag = ClassTag[DataRow](classOf[DataRow])
  def typeInfo : DataRowTypeInfo
  def transform(dataRow: DataRow, values:Any*) : DataRow
}

class FieldMapperRecordTransformer private[connectors](srcTypeInfo:DataRowTypeInfo, fieldSpecifications: FieldSpecification*) extends RecordTransformer {
  require(srcTypeInfo != null, s"srcTypeInfo must not be null")
  require(fieldSpecifications != null, s"fieldSpecifications must not be null")
  require(fieldSpecifications.nonEmpty, s"fieldSpecifications must not be empty")
  require(!fieldSpecifications.contains(null), s"fieldSpecifications must not contain any nulls")

  override val typeInfo = {
    val (fieldNames, elementTypes) = fieldSpecifications.flatMap {
      case ExistingField(name) if name == SELECT_ALL_CHAR || name == SELECT_ALL_CHAR_SCALA => srcTypeInfo.getFieldNames.zip(srcTypeInfo.getElementTypes)
      case ExistingField(name) => Seq(name -> srcTypeInfo.getFieldType(name))
      case NewField(name, newFieldTypeInfo) => Seq(name -> newFieldTypeInfo)
    }.unzip
    require(fieldNames.length == fieldNames.distinct.length, s"Fields can't have duplicates. Fields were $fieldNames.")
    new DataRowTypeInfo(fieldNames, elementTypes)
  }

  private def newFieldsNames = fieldSpecifications.collect{ case newValue: NewField => newValue.name }

  override def transform(dataRow: DataRow, values:Any*) : DataRow = {
    require(dataRow != null, s"dataRow must not be null")
    require(values != null, s"values must not be null")
    require(newFieldsNames.length == values.length, s"Must specify values for all new fields and only new fields. New fields are '$newFieldsNames'")

    val resultValues = new Array[Any](typeInfo.getArity)
    @tailrec
    def transform(index:Int, remainingSpecs: Seq[FieldSpecification], remainingValues:Seq[Any]) : DataRow = {
      if(remainingSpecs.isEmpty) {
        new DataRow(resultValues, typeInfo)
      } else {
        val currentSpec = remainingSpecs.head
        currentSpec match {
          case ExistingField(name) if name == SELECT_ALL_CHAR || name == SELECT_ALL_CHAR_SCALA =>
            Array.copy(dataRow.data, 0, resultValues, index, dataRow.data.length)
            transform(index + dataRow.data.length, remainingSpecs.tail, remainingValues)
          case ExistingField(name) =>
            resultValues(index) = dataRow(name)
            transform(index + 1, remainingSpecs.tail, remainingValues)
          case NewField(name, _) =>
            resultValues(index) = remainingValues.head
            transform(index + 1, remainingSpecs.tail, remainingValues.tail)
        }
      }
    }
    transform(0, fieldSpecifications, values)
  }
}

object RecordTransformer {
  def mapFields(srcTypeInfo: DataRowTypeInfo, fieldSpecifications: FieldSpecification*) : RecordTransformer = {
    new FieldMapperRecordTransformer(srcTypeInfo, fieldSpecifications:_*)
  }
} 
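A hedged sketch, assuming srcInfo and row come from an upstream DataRow as in the previous example: keep every existing column and append one new string column whose value is supplied at transform time:

import org.apache.flink.api.common.typeinfo.BasicTypeInfo

val transformer = RecordTransformer.mapFields(
  srcInfo,
  ExistingField("*"),
  NewField("fullName", BasicTypeInfo.STRING_TYPE_INFO))

val enriched: DataRow = transformer.transform(row, "Jane Doe")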
Example 33
Source File: DAGSchedulerEvent.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.util.Properties

import scala.language.existentials

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{AccumulatorV2, CallSite}


private[scheduler] case class MapStageSubmitted(
  jobId: Int,
  dependency: ShuffleDependency[_, _, _],
  callSite: CallSite,
  listener: JobListener,
  properties: Properties = null)
  extends DAGSchedulerEvent

private[scheduler] case class StageCancelled(
    stageId: Int,
    reason: Option[String])
  extends DAGSchedulerEvent

private[scheduler] case class JobCancelled(
    jobId: Int,
    reason: Option[String])
  extends DAGSchedulerEvent

private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent

private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent

private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent

private[scheduler]
case class GettingResultEvent(taskInfo: TaskInfo) extends DAGSchedulerEvent

private[scheduler] case class CompletionEvent(
    task: Task[_],
    reason: TaskEndReason,
    result: Any,
    accumUpdates: Seq[AccumulatorV2[_, _]],
    taskInfo: TaskInfo)
  extends DAGSchedulerEvent

private[scheduler] case class ExecutorAdded(execId: String, host: String) extends DAGSchedulerEvent

private[scheduler] case class ExecutorLost(execId: String, reason: ExecutorLossReason)
  extends DAGSchedulerEvent

private[scheduler] case class WorkerRemoved(workerId: String, host: String, message: String)
  extends DAGSchedulerEvent

private[scheduler]
case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable])
  extends DAGSchedulerEvent

private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent

private[scheduler]
case class SpeculativeTaskSubmitted(task: Task[_]) extends DAGSchedulerEvent 
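A hypothetical handler sketch (these classes are private[scheduler], so code like this would have to live inside org.apache.spark.scheduler); the existentials import covers the wildcards in Task[_] and ShuffleDependency[_, _, _] carried by the events above:

def describe(event: DAGSchedulerEvent): String = event match {
  case BeginEvent(task, info)       => s"task ${info.taskId} of stage ${task.stageId} started"
  case ExecutorLost(execId, reason) => s"executor $execId lost: $reason"
  case ResubmitFailedStages         => "resubmitting failed stages"
  case _                            => "other scheduler event"
}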
Example 34
Source File: NettyBlockRpcServer.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network.netty

import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import scala.language.existentials
import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.NioManagedBuffer
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, StorageLevel}


class NettyBlockRpcServer(
    appId: String,
    serializer: Serializer,
    blockManager: BlockDataManager)
  extends RpcHandler with Logging {

  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      rpcMessage: ByteBuffer,
      responseContext: RpcResponseCallback): Unit = {
    val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
    logTrace(s"Received request: $message")

    message match {
      case openBlocks: OpenBlocks =>
        val blocksNum = openBlocks.blockIds.length
        val blocks = for (i <- (0 until blocksNum).view)
          yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
        logTrace(s"Registered streamId $streamId with $blocksNum buffers")
        responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)

      case uploadBlock: UploadBlock =>
        // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
        val (level: StorageLevel, classTag: ClassTag[_]) = {
          serializer
            .newInstance()
            .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
            .asInstanceOf[(StorageLevel, ClassTag[_])]
        }
        val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
        val blockId = BlockId(uploadBlock.blockId)
        blockManager.putBlockData(blockId, data, level, classTag)
        responseContext.onSuccess(ByteBuffer.allocate(0))
    }
  }

  override def getStreamManager(): StreamManager = streamManager
} 
Example 35
Source File: InsertIntoHiveDirCommand.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import scala.language.existentials

import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.common.FileUtils
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.hadoop.mapred._

import org.apache.spark.SparkException
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.hive.client.HiveClientImpl


case class InsertIntoHiveDirCommand(
    isLocal: Boolean,
    storage: CatalogStorageFormat,
    query: LogicalPlan,
    overwrite: Boolean,
    outputColumns: Seq[Attribute]) extends SaveAsHiveFile {

  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
    assert(storage.locationUri.nonEmpty)

    val hiveTable = HiveClientImpl.toHiveTable(CatalogTable(
      identifier = TableIdentifier(storage.locationUri.get.toString, Some("default")),
      tableType = org.apache.spark.sql.catalyst.catalog.CatalogTableType.VIEW,
      storage = storage,
      schema = query.schema
    ))
    hiveTable.getMetadata.put(serdeConstants.SERIALIZATION_LIB,
      storage.serde.getOrElse(classOf[LazySimpleSerDe].getName))

    val tableDesc = new TableDesc(
      hiveTable.getInputFormatClass,
      hiveTable.getOutputFormatClass,
      hiveTable.getMetadata
    )

    val hadoopConf = sparkSession.sessionState.newHadoopConf()
    val jobConf = new JobConf(hadoopConf)

    val targetPath = new Path(storage.locationUri.get)
    val writeToPath =
      if (isLocal) {
        val localFileSystem = FileSystem.getLocal(jobConf)
        localFileSystem.makeQualified(targetPath)
      } else {
        val qualifiedPath = FileUtils.makeQualified(targetPath, hadoopConf)
        val dfs = qualifiedPath.getFileSystem(jobConf)
        if (!dfs.exists(qualifiedPath)) {
          dfs.mkdirs(qualifiedPath.getParent)
        }
        qualifiedPath
      }

    val tmpPath = getExternalTmpPath(sparkSession, hadoopConf, writeToPath)
    val fileSinkConf = new org.apache.spark.sql.hive.HiveShim.ShimFileSinkDesc(
      tmpPath.toString, tableDesc, false)

    try {
      saveAsHiveFile(
        sparkSession = sparkSession,
        plan = child,
        hadoopConf = hadoopConf,
        fileSinkConf = fileSinkConf,
        outputLocation = tmpPath.toString,
        allColumns = outputColumns)

      val fs = writeToPath.getFileSystem(hadoopConf)
      if (overwrite && fs.exists(writeToPath)) {
        fs.listStatus(writeToPath).foreach { existFile =>
          if (Option(existFile.getPath) != createdTempDir) fs.delete(existFile.getPath, true)
        }
      }

      fs.listStatus(tmpPath).foreach {
        tmpFile => fs.rename(tmpFile.getPath, writeToPath)
      }
    } catch {
      case e: Throwable =>
        throw new SparkException(
          "Failed inserting overwrite directory " + storage.locationUri.get, e)
    } finally {
      deleteExternalTmpPath(hadoopConf)
    }

    Seq.empty[Row]
  }
} 
Example 36
Source File: FPTreeSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.fpm

import scala.language.existentials

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("add transaction") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    assert(tree.root.children.size == 2)
    assert(tree.root.children.contains("a"))
    assert(tree.root.children("a").item.equals("a"))
    assert(tree.root.children("a").count == 2)
    assert(tree.root.children.contains("b"))
    assert(tree.root.children("b").item.equals("b"))
    assert(tree.root.children("b").count == 1)
    var child = tree.root.children("a")
    assert(child.children.size == 1)
    assert(child.children.contains("b"))
    assert(child.children("b").item.equals("b"))
    assert(child.children("b").count == 2)
    child = child.children("b")
    assert(child.children.size == 2)
    assert(child.children.contains("c"))
    assert(child.children.contains("y"))
    assert(child.children("c").item.equals("c"))
    assert(child.children("y").item.equals("y"))
    assert(child.children("c").count == 1)
    assert(child.children("y").count == 1)
  }

  test("merge tree") {
    val tree1 = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    val tree2 = new FPTree[String]
      .add(Seq("a", "b"))
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "c", "d"))
      .add(Seq("a", "x"))
      .add(Seq("a", "x", "y"))
      .add(Seq("c", "n"))
      .add(Seq("c", "m"))

    val tree3 = tree1.merge(tree2)

    assert(tree3.root.children.size == 3)
    assert(tree3.root.children("a").count == 7)
    assert(tree3.root.children("b").count == 1)
    assert(tree3.root.children("c").count == 2)
    val child1 = tree3.root.children("a")
    assert(child1.children.size == 2)
    assert(child1.children("b").count == 5)
    assert(child1.children("x").count == 2)
    val child2 = child1.children("b")
    assert(child2.children.size == 2)
    assert(child2.children("y").count == 1)
    assert(child2.children("c").count == 3)
    val child3 = child2.children("c")
    assert(child3.children.size == 1)
    assert(child3.children("d").count == 1)
    val child4 = child1.children("x")
    assert(child4.children.size == 1)
    assert(child4.children("y").count == 1)
    val child5 = tree3.root.children("c")
    assert(child5.children.size == 2)
    assert(child5.children("n").count == 1)
    assert(child5.children("m").count == 1)
  }

  test("extract freq itemsets") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("a", "b"))
      .add(Seq("a"))
      .add(Seq("b"))
      .add(Seq("b", "n"))

    val freqItemsets = tree.extract(3L).map { case (items, count) =>
      (items.toSet, count)
    }.toSet
    val expected = Set(
      (Set("a"), 4L),
      (Set("b"), 5L),
      (Set("a", "b"), 3L))
    assert(freqItemsets === expected)
  }
} 
Example 37
Source File: HadoopUtils.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.image

import scala.language.existentials
import scala.util.Random

import org.apache.commons.io.FilenameUtils
import org.apache.hadoop.conf.{Configuration, Configured}
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat

import org.apache.spark.sql.SparkSession

private object RecursiveFlag {
  
  def withPathFilter[T](
      sampleRatio: Double,
      spark: SparkSession,
      seed: Long)(f: => T): T = {
    val sampleImages = sampleRatio < 1
    if (sampleImages) {
      val flagName = FileInputFormat.PATHFILTER_CLASS
      val hadoopConf = spark.sparkContext.hadoopConfiguration
      val old = Option(hadoopConf.getClass(flagName, null))
      hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio)
      hadoopConf.setLong(SamplePathFilter.seedParam, seed)
      hadoopConf.setClass(flagName, classOf[SamplePathFilter], classOf[PathFilter])
      try f finally {
        hadoopConf.unset(SamplePathFilter.ratioParam)
        hadoopConf.unset(SamplePathFilter.seedParam)
        old match {
          case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter])
          case None => hadoopConf.unset(flagName)
        }
      }
    } else {
      f
    }
  }
} 
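A hedged sketch (the object is package-private, so such code would have to live under org.apache.spark.ml.image, and spark is an assumed active SparkSession): list roughly 10% of the files under a directory while the sampling filter is installed:

val sampledPaths = RecursiveFlag.withPathFilter(sampleRatio = 0.1, spark = spark, seed = 42L) {
  spark.sparkContext.binaryFiles("hdfs:///data/images").keys.collect()
}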
Example 38
Source File: IStep.scala    From incubator-s2graph   with Apache License 2.0 5 votes vote down vote up
package org.apache.s2graph.core.step

import org.apache.s2graph.core._
import rx.lang.scala.Observable

import scala.language.higherKinds
import scala.language.existentials

trait RxStep[-A, +B] extends (A => Observable[B])

object RxStep {

  case class VertexFetchStep(g: S2GraphLike) extends RxStep[Seq[S2VertexLike], S2VertexLike] {
    override def apply(vertices: Seq[S2VertexLike]): Observable[S2VertexLike] = {
      Observable.from(vertices)
    }
  }

  case class EdgeFetchStep(g: S2GraphLike, qp: QueryParam) extends RxStep[S2VertexLike, S2EdgeLike] {
    override def apply(v: S2VertexLike): Observable[S2EdgeLike] = {
      implicit val ec = g.ec

      val step = org.apache.s2graph.core.Step(Seq(qp))
      val q = Query(Seq(v), steps = Vector(step))

      val f = g.getEdges(q).map { stepResult =>
        val edges = stepResult.edgeWithScores.map(_.edge)
        Observable.from(edges)
      }

      Observable.from(f).flatten
    }
  }

  private def merge[A, B](steps: RxStep[A, B]*): RxStep[A, B] = new RxStep[A, B] {
    override def apply(in: A): Observable[B] =
      steps.map(_.apply(in)).toObservable.flatten
  }

  def toObservable(q: Query)(implicit graph: S2GraphLike): Observable[S2EdgeLike] = {
    val v1: Observable[S2VertexLike] = VertexFetchStep(graph).apply(q.vertices)

    val serialSteps = q.steps.map { step =>
      val parallelSteps = step.queryParams.map(qp => EdgeFetchStep(graph, qp))
      merge(parallelSteps: _*)
    }

    v1.flatMap { v =>
      val initOpt = serialSteps.headOption.map(_.apply(v))

      initOpt.map { init =>
        serialSteps.tail.foldLeft(init) { case (prev, next) =>
          prev.map(_.tgtForVertex).flatMap(next)
        }
      }.getOrElse(Observable.empty)
    }
  }
} 
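A hypothetical extra step, only to illustrate the RxStep shape: it re-emits each edge unchanged, where a real step would transform or filter it:

case class IdentityEdgeStep() extends RxStep[S2EdgeLike, S2EdgeLike] {
  override def apply(e: S2EdgeLike): Observable[S2EdgeLike] = Observable.just(e)
}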
Example 39
Source File: ArrayBasedMapData.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import java.util.{Map => JavaMap}

class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData {
  require(keyArray.numElements() == valueArray.numElements())

  override def numElements(): Int = keyArray.numElements()

  override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy())

  override def toString: String = {
    s"keys: $keyArray, values: $valueArray"
  }
}

object ArrayBasedMapData {
  
  def apply(
      iterator: Iterator[(_, _)],
      size: Int,
      keyConverter: (Any) => Any,
      valueConverter: (Any) => Any): ArrayBasedMapData = {

    val keys: Array[Any] = new Array[Any](size)
    val values: Array[Any] = new Array[Any](size)

    var i = 0
    for ((key, value) <- iterator) {
      keys(i) = keyConverter(key)
      values(i) = valueConverter(value)
      i += 1
    }
    ArrayBasedMapData(keys, values)
  }

  def apply(keys: Array[_], values: Array[_]): ArrayBasedMapData = {
    new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
  }

  def toScalaMap(map: ArrayBasedMapData): Map[Any, Any] = {
    val keys = map.keyArray.asInstanceOf[GenericArrayData].array
    val values = map.valueArray.asInstanceOf[GenericArrayData].array
    keys.zip(values).toMap
  }

  def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = {
    keys.zip(values).toMap
  }

  def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
    keys.zip(values).toMap
  }

  def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = {
    import scala.collection.JavaConverters._
    keys.zip(values).toMap.asJava
  }
} 
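A small sketch: build map data from parallel key/value arrays and turn it back into a Scala Map with the helpers above:

val mapData  = ArrayBasedMapData(Array("a", "b"), Array(1, 2))
val scalaMap = ArrayBasedMapData.toScalaMap(mapData)   // Map(a -> 1, b -> 2)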
Example 40
Source File: VLinearRegressionSuite.scala    From spark-vlbfgs   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.regression

import scala.language.existentials
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame


class VLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {

  import testImplicits._
  var datasetWithWeight: DataFrame = _

  override def beforeAll(): Unit = {
    super.beforeAll()

    datasetWithWeight = sc.parallelize(Seq(
      Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
      Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
      Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
      Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
    ), 2).toDF()
  }

  test("test on datasetWithWeight") {

    def b2s(b: Boolean): String = {
      if (b) "w/" else "w/o"
    }

    for (fitIntercept <- Seq(false, true)) {
      for (standardization <- Seq(false, true)) {
        for ((reg, elasticNet)<- Seq((0.0, 0.0), (2.3, 0.0), (2.3, 0.5))) {

          println()
          println(s"# test ${b2s(fitIntercept)} intercept, ${b2s(standardization)} standardization, reg=${reg}, elasticNet=${elasticNet}")

          val vtrainer = new VLinearRegression()
            .setColsPerBlock(1)
            .setRowsPerBlock(1)
            .setGeneratingFeatureMatrixBuffer(2)
            .setFitIntercept(fitIntercept)
            .setStandardization(standardization)
            .setRegParam(reg)
            .setWeightCol("weight")
            .setElasticNetParam(elasticNet)
          val vmodel = vtrainer.fit(datasetWithWeight)

          // Note that in ml.LinearRegression, when the dataset's number of instances is small,
          // the l-bfgs and normal solvers produce slightly different results when reg is non-zero,
          // because their std calculations differ by a factor of numInstances / (numInstances - 1).
          // This test stays consistent with the l-bfgs solver.
          val trainer = new LinearRegression()
            .setSolver("l-bfgs") // by default it may use the normal solver, so force l-bfgs here.
            .setFitIntercept(fitIntercept)
            .setStandardization(standardization)
            .setRegParam(reg)
            .setWeightCol("weight")
            .setElasticNetParam(elasticNet)

          val model = trainer.fit(datasetWithWeight)
          logInfo(s"LinearRegression total iterations: ${model.summary.totalIterations}")

          println(s"VLinearRegression coefficients: ${vmodel.coefficients.toDense}, intercept: ${vmodel.intercept}\n" +
            s"LinearRegression coefficients: ${model.coefficients.toDense}, intercept: ${model.intercept}")

          def filterSmallValue(v: Vector) = {
            Vectors.dense(v.toArray.map(x => if (math.abs(x) < 1e-6) 0.0 else x))
          }
          assert(filterSmallValue(vmodel.coefficients) ~== filterSmallValue(model.coefficients) relTol 1e-3)
          assert(vmodel.intercept ~== model.intercept relTol 1e-3)
        }
      }
    }
  }
} 
Example 41
Source File: VSoftmaxRegressionSuite.scala    From spark-vlbfgs   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.classification

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{SparseMatrix, Vector, Vectors}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset}

import scala.language.existentials


class VSoftmaxRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {

  import testImplicits._

  private val seed = 42
  @transient var multinomialDataset: Dataset[_] = _
  private val eps: Double = 1e-5

  override def beforeAll(): Unit = {
    super.beforeAll()

    multinomialDataset = {
      val nPoints = 50
      val coefficients = Array(
        -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
        -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)

      val xMean = Array(5.843, 3.057, 3.758, 1.199)
      val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)

      val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
        coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)

      val df = sc.parallelize(testData, 4).toDF().withColumn("weight", rand(seed))
      df.cache()
      println("softmax test data:")
      df.show(10, false)
      df
    }
  }

  test("test on multinomialDataset") {

    def b2s(b: Boolean): String = {
      if (b) "w/" else "w/o"
    }

    for (standardization <- Seq(false, true)) {
      for ((reg, elasticNet) <- Seq((0.0, 0.0), (2.3, 0.0), (0.3, 0.05), (0.01, 1.0))) {
        println()
        println(s"# test ${b2s(standardization)} standardization, reg=${reg}, elasticNet=${elasticNet}")

        val trainer = new LogisticRegression()
          .setFamily("multinomial")
          .setStandardization(standardization)
          .setWeightCol("weight")
          .setRegParam(reg)
          .setFitIntercept(false)
          .setElasticNetParam(elasticNet)

        val model = trainer.fit(multinomialDataset)

        val vtrainer = new VSoftmaxRegression()
          .setColsPerBlock(2)
          .setRowsPerBlock(5)
          .setColPartitions(2)
          .setRowPartitions(3)
          .setWeightCol("weight")
          .setGeneratingFeatureMatrixBuffer(2)
          .setStandardization(standardization)
          .setRegParam(reg)
          .setElasticNetParam(elasticNet)
        val vmodel = vtrainer.fit(multinomialDataset)

        println(s"VSoftmaxRegression coefficientMatrix:\n" +
          s"${vmodel.coefficientMatrix.asInstanceOf[SparseMatrix].toDense},\n" +
          s"ml.SoftmaxRegression coefficientMatrix:\n" +
          s"${model.coefficientMatrix}\n")

        assert(vmodel.coefficientMatrix ~== model.coefficientMatrix relTol eps)
      }
    }
  }
} 
Example 42
Source File: SortedMapDeserializerModule.scala    From mango   with Apache License 2.0 5 votes vote down vote up
package com.kakao.shaded.jackson.module.scala.deser

import java.util.AbstractMap
import java.util.Map.Entry

import scala.collection.{mutable, SortedMap}
import scala.collection.immutable.TreeMap

import com.kakao.shaded.jackson.core.JsonParser
import com.kakao.shaded.jackson.databind._
import com.kakao.shaded.jackson.databind.deser.std.{MapDeserializer, ContainerDeserializerBase}
import com.kakao.shaded.jackson.databind.jsontype.TypeDeserializer
import com.kakao.shaded.jackson.databind.`type`.MapLikeType
import com.kakao.shaded.jackson.module.scala.modifiers.MapTypeModifierModule
import deser.{ContextualDeserializer, Deserializers, ValueInstantiator}
import com.kakao.shaded.jackson.module.scala.introspect.OrderingLocator
import scala.language.existentials

private class SortedMapBuilderWrapper[K,V](val builder: mutable.Builder[(K,V), SortedMap[K,V]]) extends AbstractMap[K,V] {
  override def put(k: K, v: V) = { builder += ((k,v)); v }

  // Isn't used by the deserializer
  def entrySet(): java.util.Set[Entry[K, V]] = throw new UnsupportedOperationException
}

private object SortedMapDeserializer {
  def orderingFor = OrderingLocator.locate _

  def builderFor(cls: Class[_], keyCls: JavaType): mutable.Builder[(AnyRef,AnyRef), SortedMap[AnyRef,AnyRef]] =
    if (classOf[TreeMap[_,_]].isAssignableFrom(cls)) TreeMap.newBuilder[AnyRef,AnyRef](orderingFor(keyCls)) else
    SortedMap.newBuilder[AnyRef,AnyRef](orderingFor(keyCls))
}

private class SortedMapDeserializer(
    collectionType: MapLikeType,
    config: DeserializationConfig,
    keyDeser: KeyDeserializer,
    valueDeser: JsonDeserializer[_],
    valueTypeDeser: TypeDeserializer)
  extends ContainerDeserializerBase[SortedMap[_,_]](collectionType)
  with ContextualDeserializer {
  
  private val javaContainerType =
    config.getTypeFactory.constructMapLikeType(classOf[MapBuilderWrapper[_,_]], collectionType.getKeyType, collectionType.getContentType)

  private val instantiator =
    new ValueInstantiator {
      def getValueTypeDesc = collectionType.getRawClass.getCanonicalName
      override def canCreateUsingDefault = true
      override def createUsingDefault(ctx: DeserializationContext) =
        new SortedMapBuilderWrapper[AnyRef,AnyRef](SortedMapDeserializer.builderFor(collectionType.getRawClass, collectionType.getKeyType))
    }

  private val containerDeserializer =
    new MapDeserializer(javaContainerType,instantiator,keyDeser,valueDeser.asInstanceOf[JsonDeserializer[AnyRef]],valueTypeDeser)

  override def getContentType = containerDeserializer.getContentType

  override def getContentDeserializer = containerDeserializer.getContentDeserializer

  override def createContextual(ctxt: DeserializationContext, property: BeanProperty) =
    if (keyDeser != null && valueDeser != null) this
    else {
      val newKeyDeser = Option(keyDeser).getOrElse(ctxt.findKeyDeserializer(collectionType.getKeyType, property))
      val newValDeser = Option(valueDeser).getOrElse(ctxt.findContextualValueDeserializer(collectionType.getContentType, property))
      new SortedMapDeserializer(collectionType, config, newKeyDeser, newValDeser, valueTypeDeser)
    }
  
  override def deserialize(jp: JsonParser, ctxt: DeserializationContext): SortedMap[_,_] = {
    containerDeserializer.deserialize(jp,ctxt) match {
      case wrapper: SortedMapBuilderWrapper[_,_] => wrapper.builder.result()
    }
  }
}

private object SortedMapDeserializerResolver extends Deserializers.Base {
  
  private val SORTED_MAP = classOf[collection.SortedMap[_,_]]

  override def findMapLikeDeserializer(theType: MapLikeType,
                              config: DeserializationConfig,
                              beanDesc: BeanDescription,
                              keyDeserializer: KeyDeserializer,
                              elementTypeDeserializer: TypeDeserializer,
                              elementDeserializer: JsonDeserializer[_]): JsonDeserializer[_] =
    if (!SORTED_MAP.isAssignableFrom(theType.getRawClass)) null
    else new SortedMapDeserializer(theType,config,keyDeserializer,elementDeserializer,elementTypeDeserializer)
}


trait SortedMapDeserializerModule extends MapTypeModifierModule {
  this += (_ addDeserializers SortedMapDeserializerResolver)
} 
Example 43
Source File: PropertyDescriptor.scala    From mango   with Apache License 2.0 5 votes vote down vote up
package com.kakao.shaded.jackson.module.scala
package introspect

import util.Implicits._
import java.lang.reflect.{AccessibleObject, Constructor, Field, Method}

import scala.language.existentials

case class ConstructorParameter(constructor: Constructor[_], index: Int, defaultValueMethod: Option[Method])

case class PropertyDescriptor(name: String,
                              param: Option[ConstructorParameter],
                              field: Option[Field],
                              getter: Option[Method],
                              setter: Option[Method],
                              beanGetter: Option[Method],
                              beanSetter: Option[Method])
{
  if (List(field, getter).flatten.isEmpty) throw new IllegalArgumentException("One of field or getter must be defined.")

  def findAnnotation[A <: java.lang.annotation.Annotation](implicit mf: Manifest[A]): Option[A] = {
    val cls = mf.runtimeClass.asInstanceOf[Class[A]]
    lazy val paramAnnotation = (param flatMap { cp =>
      val paramAnnos = cp.constructor.getParameterAnnotations
      paramAnnos(cp.index).find(cls.isInstance)
    }).asInstanceOf[Option[A]]
    val getAnno = (o: AccessibleObject) => o.getAnnotation(cls)
    lazy val fieldAnnotation = field optMap getAnno
    lazy val getterAnnotation = getter optMap getAnno
    lazy val beanGetterAnnotation = beanGetter optMap getAnno

    paramAnnotation orElse fieldAnnotation orElse getterAnnotation orElse beanGetterAnnotation
  }

} 
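A minimal sketch with a hypothetical Person class: a descriptor needs at least a field or a getter, so a getter-only property is described reflectively and then queried for an annotation it does not carry:

class Person { def name: String = "Jane" }

val getter = classOf[Person].getMethod("name")
val desc   = PropertyDescriptor("name", None, None, Some(getter), None, None, None)
val anno   = desc.findAnnotation[java.lang.Deprecated]   // None: nothing is annotated here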
Example 44
Source File: ReplicationFilterSerializer.scala    From eventuate   with Apache License 2.0 5 votes vote down vote up
package com.rbmhtechnology.eventuate.serializer

import akka.actor.ExtendedActorSystem
import akka.serialization._

import com.rbmhtechnology.eventuate.ReplicationFilter.AndFilter
import com.rbmhtechnology.eventuate.ReplicationFilter.NoFilter
import com.rbmhtechnology.eventuate.ReplicationFilter.OrFilter

import com.rbmhtechnology.eventuate._
import com.rbmhtechnology.eventuate.serializer.ReplicationFilterFormats._

import scala.collection.JavaConverters._
import scala.language.existentials

class ReplicationFilterSerializer(system: ExtendedActorSystem) extends Serializer {
  import ReplicationFilterTreeFormat.NodeType._

  val payloadSerializer = new DelegatingPayloadSerializer(system)

  val AndFilterClass = classOf[AndFilter]
  val OrFilterClass = classOf[OrFilter]
  val NoFilterClass = NoFilter.getClass

  override def identifier: Int = 22564
  override def includeManifest: Boolean = true

  override def toBinary(o: AnyRef): Array[Byte] = o match {
    case NoFilter =>
      NoFilterFormat.newBuilder().build().toByteArray
    case f: ReplicationFilter =>
      filterTreeFormatBuilder(f).build().toByteArray
    case _ =>
      throw new IllegalArgumentException(s"can't serialize object of type ${o.getClass}")
  }

  override def fromBinary(bytes: Array[Byte], manifest: Option[Class[_]]): AnyRef = manifest match {
    case None => throw new IllegalArgumentException("manifest required")
    case Some(clazz) => clazz match {
      case NoFilterClass =>
        NoFilter
      case AndFilterClass | OrFilterClass =>
        filterTree(ReplicationFilterTreeFormat.parseFrom(bytes))
      case _ =>
        throw new IllegalArgumentException(s"can't deserialize object of type ${clazz}")
    }
  }

  // --------------------------------------------------------------------------------
  //  toBinary helpers
  // --------------------------------------------------------------------------------

  def filterTreeFormatBuilder(filterTree: ReplicationFilter): ReplicationFilterTreeFormat.Builder = {
    val builder = ReplicationFilterTreeFormat.newBuilder()
    filterTree match {
      case AndFilter(filters) =>
        builder.setNodeType(AND)
        filters.foreach(filter => builder.addChildren(filterTreeFormatBuilder(filter)))
      case OrFilter(filters) =>
        builder.setNodeType(OR)
        filters.foreach(filter => builder.addChildren(filterTreeFormatBuilder(filter)))
      case filter =>
        builder.setNodeType(LEAF)
        builder.setFilter(payloadSerializer.payloadFormatBuilder(filter))
    }
    builder
  }

  // --------------------------------------------------------------------------------
  //  fromBinary helpers
  // --------------------------------------------------------------------------------

  def filterTree(filterTreeFormat: ReplicationFilterTreeFormat): ReplicationFilter = {
    filterTreeFormat.getNodeType match {
      case AND  => AndFilter(filterTreeFormat.getChildrenList.asScala.map(filterTree).toList)
      case OR   => OrFilter(filterTreeFormat.getChildrenList.asScala.map(filterTree).toList)
      case LEAF => payloadSerializer.payload(filterTreeFormat.getFilter).asInstanceOf[ReplicationFilter]
    }
  }
} 
Example 45
Source File: BytecodeUtils.scala    From graphx-algorithm   with GNU General Public License v2.0 5 votes vote down vote up
package org.apache.spark.graphx.util

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.HashSet
import scala.language.existentials

import org.apache.spark.util.Utils

import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._

private[graphx] object BytecodeUtils {

  private class MethodInvocationFinder(className: String, methodName: String)
    extends ClassVisitor(ASM4) {

    val methodsInvoked = new HashSet[(Class[_], String)]

    override def visitMethod(access: Int, name: String, desc: String,
                             sig: String, exceptions: Array[String]): MethodVisitor = {
      if (name == methodName) {
        new MethodVisitor(ASM4) {
          override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
            if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) {
              if (!skipClass(owner)) {
                methodsInvoked.add((Class.forName(owner.replace("/", ".")), name))
              }
            }
          }
        }
      } else {
        null
      }
    }
  }
} 
Example 46
Source File: LoggingState.scala    From logging   with Apache License 2.0 5 votes vote down vote up
package com.persist.logging

import akka.actor._
import LogActor.{AkkaMessage, LogActorMessage}
import scala.language.existentials
import scala.concurrent.Promise
import scala.collection.mutable
import TimeActorMessages._


private[logging] object LoggingState extends ClassLogging {

  // Queue of messages sent before logger is started
  private[logging] val msgs = new mutable.Queue[LogActorMessage]()

  @volatile var doTrace: Boolean = false
  @volatile var doDebug: Boolean = false
  @volatile var doInfo: Boolean = true
  @volatile var doWarn: Boolean = true
  @volatile var doError: Boolean = true

  private[logging] var loggingSys: LoggingSystem = null

  private[logging] var logger: Option[ActorRef] = None
  @volatile private[logging] var loggerStopping = false

  private[logging] var doTime: Boolean = false
  private[logging] var timeActorOption: Option[ActorRef] = None


  // Use to sync akka logging actor shutdown
  private[logging] val akkaStopPromise = Promise[Unit]

  private[logging] def sendMsg(msg: LogActorMessage) {
    if (loggerStopping) {
      println(s"*** Log message received after logger shutdown: $msg")
    } else {
      logger match {
        case Some(a) =>
          a ! msg
        case None =>
          msgs.synchronized {
            msgs.enqueue(msg)
          }
      }
    }
  }

  private[logging] def akkaMsg(m: AkkaMessage) {
    if (m.msg == "DIE") {
      akkaStopPromise.trySuccess(())
    } else {
      sendMsg(m)
    }
  }

  private[logging] def timeStart(id: RequestId, name: String, uid: String) {
    timeActorOption foreach {
      case timeActor =>
        val time = System.nanoTime() / 1000
        timeActor ! TimeStart(id, name, uid, time)
    }
  }

  private[logging] def timeEnd(id: RequestId, name: String, uid: String) {
    timeActorOption foreach {
      case timeActor =>
        val time = System.nanoTime() / 1000
        timeActor ! TimeEnd(id, name, uid, time)
    }
  }
} 
Example 47
Source File: Query.scala    From finagle-postgres   with Apache License 2.0 5 votes vote down vote up
package com.twitter.finagle.postgres.generic

import com.twitter.concurrent.AsyncStream

import scala.collection.immutable.Queue
import com.twitter.finagle.postgres.{Param, PostgresClient, Row}
import com.twitter.util.Future

import scala.language.existentials

case class Query[T](parts: Seq[String], queryParams: Seq[QueryParam], cont: Row => T) {

  def stream(client: PostgresClient): AsyncStream[T] = {
    val (queryString, params) = impl
    client.prepareAndQueryToStream[T](queryString, params: _*)(cont)
  }

  def run(client: PostgresClient): Future[Seq[T]] =
    stream(client).toSeq

  def exec(client: PostgresClient): Future[Int] = {
    val (queryString, params) = impl
    client.prepareAndExecute(queryString, params: _*)
  }

  def map[U](fn: T => U): Query[U] = copy(cont = cont andThen fn)

  def as[U](implicit rowDecoder: RowDecoder[U], columnNamer: ColumnNamer): Query[U] = {
    copy(cont = row => rowDecoder(row)(columnNamer))
  }

  private def impl: (String, Seq[Param[_]]) = {
    val (last, placeholders, params) = queryParams.foldLeft((1, Queue.empty[Seq[String]], Queue.empty[Param[_]])) {
      case ((start, placeholders, params), next) =>
        val nextPlaceholders = next.placeholders(start)
        val nextParams = Queue(next.params: _*)
        (start + nextParams.length, placeholders enqueue nextPlaceholders, params ++ nextParams)
    }

    val queryString = parts.zipAll(placeholders, "", Seq.empty).flatMap {
      case (part, ph) => Seq(part, ph.mkString(", "))
    }.mkString

    (queryString, params)
  }


}

object Query {
  implicit class RowQueryOps(val self: Query[Row]) extends AnyVal {
    def ++(that: Query[Row]): Query[Row] = Query[Row](
      parts = if(self.parts.length > self.queryParams.length)
        (self.parts.dropRight(1) :+ (self.parts.lastOption.getOrElse("") + that.parts.headOption.getOrElse(""))) ++ that.parts.drop(1)
      else
        self.parts ++ that.parts,
      queryParams = self.queryParams ++ that.queryParams,
      cont = self.cont
    )

    def ++(that: String): Query[Row] = Query[Row](
      parts = if(self.parts.length > self.queryParams.length)
          self.parts.dropRight(1) :+ (self.parts.last + that)
        else
          self.parts :+ that,
      queryParams = self.queryParams,
      cont = self.cont
    )
  }
} 
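
A short usage sketch for the Query type above. It assumes the sql string interpolator exposed by the com.twitter.finagle.postgres.generic package and an already-connected PostgresClient named client; both are assumptions for illustration, not part of the listing.

import com.twitter.finagle.postgres.{PostgresClient, Row}
import com.twitter.finagle.postgres.generic._
import com.twitter.util.Future

// `client` is assumed to be an already-connected PostgresClient
def activeUsers(client: PostgresClient, orgId: Int): Future[Seq[Row]] = {
  val base   = sql"SELECT id, name FROM users WHERE active = ${true}"
  val scoped = base ++ sql" AND org_id = $orgId"   // Query[Row] ++ Query[Row] via RowQueryOps
  scoped.run(client)                               // prepares and runs the combined query
}
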
Example 48
Source File: InferShape.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.abstractnn

import com.intel.analytics.bigdl.nn.keras.{Input => KInput, Sequential => KSequential}
import com.intel.analytics.bigdl.nn.{Input => TInput}
import com.intel.analytics.bigdl.utils.Shape

import scala.language.existentials
import scala.reflect.ClassTag

class InvalidLayer(msg: String) extends RuntimeException(msg)

trait InferShape {
  private[bigdl] var _inputShapeValue: Shape = null

  private[bigdl] var _outputShapeValue: Shape = null

  private[bigdl] def inputShapeValue: Shape = _inputShapeValue

  private[bigdl] def outputShapeValue: Shape = _outputShapeValue

  // scalastyle:off
  private[bigdl] def inputShapeValue_=(value: Shape): Unit = {
    _inputShapeValue = value
  }

  private[bigdl] def outputShapeValue_=(value: Shape): Unit = {
    _outputShapeValue = value
  }
  // scalastyle:on

  
  private[bigdl] def computeOutputShape(inputShape: Shape): Shape = {
    throw new RuntimeException("Haven't been implemented yet. Do not use it with Keras Layer")
  }

  private[bigdl] def excludeInvalidLayers[T: ClassTag]
  (modules : Seq[AbstractModule[_, _, T]]): Unit = {
    val invalidNodes = if (this.isKerasStyle()) {
      modules.filter{!_.isKerasStyle()}
    } else {
      modules.filter{_.isKerasStyle()}
    }
    if (invalidNodes.length > 0) {
      throw new InvalidLayer(s"""Do not mix ${this}(isKerasStyle=${isKerasStyle()}) with Layer
                           (isKerasStyle=${invalidNodes(0).isKerasStyle()}):
         ${invalidNodes.mkString(",")}""")
    }
  }

  private[bigdl] def validateInput[T: ClassTag](modules : Seq[AbstractModule[_, _, T]]): Unit = {
    if (this.isKerasStyle()) {
      require(modules != null && !modules.isEmpty, "Empty input is not allowed")
    }
    excludeInvalidLayers(modules)
  }
} 
Example 49
Source File: PythonBigDLValidator.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.python.api

import java.lang.{Boolean => JBoolean}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table

import scala.collection.JavaConverters._
import scala.collection.mutable.Map
import scala.language.existentials
import scala.reflect.ClassTag

object PythonBigDLValidator {

  def ofFloat(): PythonBigDLValidator[Float] = new PythonBigDLValidator[Float]()

  def ofDouble(): PythonBigDLValidator[Double] = new PythonBigDLValidator[Double]()
}

class PythonBigDLValidator[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonBigDL[T]{

  def testDict(): JMap[String, String] = {
    return Map("jack" -> "40", "lucy" -> "50").asJava
  }

  def testDictJTensor(): JMap[String, JTensor] = {
    return Map("jack" -> JTensor(Array(1.0f, 2.0f, 3.0f, 4.0f), Array(4, 1), "float")).asJava
  }

  def testDictJMapJTensor(): JMap[String, JMap[String, JTensor]] = {
    val table = new Table()
    val tensor = JTensor(Array(1.0f, 2.0f, 3.0f, 4.0f), Array(4, 1), "float")
    val result = Map("jack" -> tensor).asJava
    table.insert(tensor)
    return Map("nested" -> result).asJava
  }

  def testActivityWithTensor(): JActivity = {
    val tensor = Tensor(Array(1.0f, 2.0f, 3.0f, 4.0f), Array(4, 1))
    return JActivity(tensor)
  }

  def testActivityWithTableOfTensor(): JActivity = {
    val tensor1 = Tensor(Array(1.0f, 1.0f), Array(2))
    val tensor2 = Tensor(Array(2.0f, 2.0f), Array(2))
    val tensor3 = Tensor(Array(3.0f, 3.0f), Array(2))
    val table = new Table()
    table.insert(tensor1)
    table.insert(tensor2)
    table.insert(tensor3)
    return JActivity(table)
  }

  def testActivityWithTableOfTable(): JActivity = {
    val tensor = Tensor(Array(1.0f, 2.0f, 3.0f, 4.0f), Array(4, 1))
    val table = new Table()
    table.insert(tensor)
    val nestedTable = new Table()
    nestedTable.insert(table)
    nestedTable.insert(table)
    return JActivity(nestedTable)
  }
} 
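
The validator above repeatedly converts Scala maps into java.util.Map for the Python bridge via scala.collection.JavaConverters. A minimal standalone sketch of that round trip:

import scala.collection.JavaConverters._

val jmap: java.util.Map[String, String] = Map("jack" -> "40", "lucy" -> "50").asJava
val back: Map[String, String] = jmap.asScala.toMap
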
Example 50
Source File: TreeSentiment.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.example.treeLSTMSentiment

import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.nn._
import com.intel.analytics.bigdl.numeric.NumericFloat
import com.intel.analytics.bigdl.tensor.Tensor

import scala.language.existentials

object TreeLSTMSentiment {
  def apply(
    word2VecTensor: Tensor[Float],
    hiddenSize: Int,
    classNum: Int,
    p: Double = 0.5
  ): Module[Float] = {
    val vocabSize = word2VecTensor.size(1)
    val embeddingDim = word2VecTensor.size(2)
    val embedding = LookupTable(vocabSize, embeddingDim)
    embedding.weight.set(word2VecTensor)
    embedding.setScaleW(2)

    val treeLSTMModule = Sequential()
      .add(BinaryTreeLSTM(
        embeddingDim, hiddenSize, withGraph = true))
      .add(TimeDistributed(Dropout(p)))
      .add(TimeDistributed(Linear(hiddenSize, classNum)))
      .add(TimeDistributed(LogSoftMax()))

    Sequential()
      .add(MapTable(Squeeze(3)))
      .add(ParallelTable()
        .add(embedding)
        .add(Identity()))
      .add(treeLSTMModule)
  }
} 
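
A hedged construction sketch for the factory above. The embedding tensor here is random dummy data purely for illustration (an assumption; a real run loads pretrained word2vec weights), and the hidden size and class count are arbitrary:

import com.intel.analytics.bigdl.numeric.NumericFloat
import com.intel.analytics.bigdl.tensor.Tensor

// 10000-word vocabulary, 200-dimensional embeddings, filled with uniform random values
val word2Vec = Tensor[Float](10000, 200).rand()
val model = TreeLSTMSentiment(word2Vec, hiddenSize = 150, classNum = 5)
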
Example 51
Source File: TextClassifier.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.example.textclassification

import com.intel.analytics.bigdl.example.utils._
import com.intel.analytics.bigdl.nn.{ClassNLLCriterion, _}
import com.intel.analytics.bigdl.utils.{Engine, LoggerFilter, T}
import org.apache.log4j.{Level => Level4j, Logger => Logger4j}
import org.slf4j.{Logger, LoggerFactory}
import scopt.OptionParser

import scala.collection.mutable.{ArrayBuffer, Map => MMap}
import scala.language.existentials

object TextClassifier {
  val log: Logger = LoggerFactory.getLogger(this.getClass)
  LoggerFilter.redirectSparkInfoLogs()
  Logger4j.getLogger("com.intel.analytics.bigdl.optim").setLevel(Level4j.INFO)

  def main(args: Array[String]): Unit = {
    val localParser = new OptionParser[TextClassificationParams]("BigDL Example") {
      opt[String]('b', "baseDir")
        .required()
        .text("Base dir containing the training and word2Vec data")
        .action((x, c) => c.copy(baseDir = x))
      opt[String]('p', "partitionNum")
        .text("you may want to tune the partitionNum if run into spark mode")
        .action((x, c) => c.copy(partitionNum = x.toInt))
      opt[String]('s', "maxSequenceLength")
        .text("maxSequenceLength")
        .action((x, c) => c.copy(maxSequenceLength = x.toInt))
      opt[String]('w', "maxWordsNum")
        .text("maxWordsNum")
        .action((x, c) => c.copy(maxWordsNum = x.toInt))
      opt[String]('l', "trainingSplit")
        .text("trainingSplit")
        .action((x, c) => c.copy(trainingSplit = x.toDouble))
      opt[String]('z', "batchSize")
        .text("batchSize")
        .action((x, c) => c.copy(batchSize = x.toInt))
      opt[Int]('l', "learningRate")
        .text("learningRate")
        .action((x, c) => c.copy(learningRate = x))
    }

    localParser.parse(args, TextClassificationParams()).map { param =>
      log.info(s"Current parameters: $param")
      val textClassification = new TextClassifier(param)
      textClassification.train()
    }
  }
} 
Example 52
Source File: ShuffleMapTask.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.lang.management.ManagementFactory
import java.nio.ByteBuffer
import java.util.Properties

import scala.language.existentials

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.storage.BlockManagerId


  def this(partitionId: Int) {
    this(0, 0, null, new Partition { override def index: Int = 0 }, null, new Properties, null)
  }

  @transient private val preferredLocs: Seq[TaskLocation] = {
    if (locs == null) Nil else locs.toSet.toSeq
  }

  var rdd: RDD[_] = null
  var dep: ShuffleDependency[_, _, _] = null

  override def prepTask(): Unit = {
    // Deserialize the RDD using the broadcast variable.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (_rdd, _dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    rdd = _rdd
    dep = _dep
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L
  }

  override def runTask(context: TaskContext): MapStatus = {
    if (dep == null || rdd == null) {
      prepTask()
    }

    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      val status = writer.stop(success = true).get
      FutureTaskNotifier.taskCompleted(status, partitionId, dep.shuffleId,
        dep.partitioner.numPartitions, nextStageLocs, metrics.shuffleWriteMetrics, false)
      status
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

  override def preferredLocations: Seq[TaskLocation] = preferredLocs

  override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
}

object ShuffleMapTask {

  def apply(
      stageId: Int,
      stageAttemptId: Int,
      partition: Partition,
      properties: Properties,
      internalAccumulatorsSer: Array[Byte],
      isFutureTask: Boolean,
      rdd: RDD[_],
      dep: ShuffleDependency[_, _, _],
      nextStageLocs: Option[Seq[BlockManagerId]]): ShuffleMapTask = {

    val smt = new ShuffleMapTask(stageId, stageAttemptId, null, partition, null,
      properties, internalAccumulatorsSer, isFutureTask, nextStageLocs)

    smt.rdd = rdd
    smt.dep = dep
    smt
  }
} 
Example 53
Source File: NettyBlockRpcServer.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network.netty

import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import scala.language.existentials
import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, MapOutputReady, OpenBlocks, StreamHandle, UploadBlock}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, StorageLevel}


class NettyBlockRpcServer(
    appId: String,
    serializer: Serializer,
    blockManager: BlockDataManager)
  extends RpcHandler with Logging {

  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      rpcMessage: ByteBuffer,
      responseContext: RpcResponseCallback): Unit = {
    val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
    logTrace(s"Received request: $message")

    message match {
      case openBlocks: OpenBlocks =>
        val blocks: Seq[ManagedBuffer] =
          openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
        logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)

      case uploadBlock: UploadBlock =>
        // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
        val (level: StorageLevel, classTag: ClassTag[_]) = {
          serializer
            .newInstance()
            .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
            .asInstanceOf[(StorageLevel, ClassTag[_])]
        }
        val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
        val blockId = BlockId(uploadBlock.blockId)
        blockManager.putBlockData(blockId, data, level, classTag)
        responseContext.onSuccess(ByteBuffer.allocate(0))

      case mapOutputReady: MapOutputReady =>
        val mapStatus: MapStatus =
          serializer.newInstance().deserialize(ByteBuffer.wrap(mapOutputReady.serializedMapStatus))
        blockManager.mapOutputReady(
          mapOutputReady.shuffleId, mapOutputReady.mapId, mapOutputReady.numReduces, mapStatus)
    }
  }

  override def getStreamManager(): StreamManager = streamManager
} 
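
The UploadBlock branch above binds a deserialized pair to (StorageLevel, ClassTag[_]), which is what pulls in scala.language.existentials. A minimal self-contained sketch of that binding pattern, in plain Scala with no Spark classes:

import scala.language.existentials
import scala.reflect.ClassTag

val payload: Any = ("metadata", implicitly[ClassTag[Array[Byte]]])

// The second element gets the existential type ClassTag[_]:
// its concrete type parameter is unknown statically at the use site.
val (label: String, tag: ClassTag[_]) = payload.asInstanceOf[(String, ClassTag[_])]
println(s"$label -> ${tag.runtimeClass.getName}")
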
Example 54
Source File: ArrayBasedMapData.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import java.util.{Map => JavaMap}

class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData {
  require(keyArray.numElements() == valueArray.numElements())

  override def numElements(): Int = keyArray.numElements()

  override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy())

  override def toString: String = {
    s"keys: $keyArray, values: $valueArray"
  }
}

object ArrayBasedMapData {
  
  def apply(
      iterator: Iterator[(_, _)],
      size: Int,
      keyConverter: (Any) => Any,
      valueConverter: (Any) => Any): ArrayBasedMapData = {

    val keys: Array[Any] = new Array[Any](size)
    val values: Array[Any] = new Array[Any](size)

    var i = 0
    for ((key, value) <- iterator) {
      keys(i) = keyConverter(key)
      values(i) = valueConverter(value)
      i += 1
    }
    ArrayBasedMapData(keys, values)
  }

  def apply(keys: Array[_], values: Array[_]): ArrayBasedMapData = {
    new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
  }

  def toScalaMap(map: ArrayBasedMapData): Map[Any, Any] = {
    val keys = map.keyArray.asInstanceOf[GenericArrayData].array
    val values = map.valueArray.asInstanceOf[GenericArrayData].array
    keys.zip(values).toMap
  }

  def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = {
    keys.zip(values).toMap
  }

  def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
    keys.zip(values).toMap
  }

  def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = {
    import scala.collection.JavaConverters._
    keys.zip(values).toMap.asJava
  }
} 
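
A small usage sketch for the helpers above, using only the API shown in this listing:

import org.apache.spark.sql.catalyst.util.ArrayBasedMapData

val mapData = ArrayBasedMapData(Array("a", "b", "c"), Array(1, 2, 3))
println(mapData.numElements())                         // 3
val scalaMap = ArrayBasedMapData.toScalaMap(mapData)   // Map(a -> 1, b -> 2, c -> 3)
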
Example 55
Source File: FPTreeSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.fpm

import scala.language.existentials

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("add transaction") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    assert(tree.root.children.size == 2)
    assert(tree.root.children.contains("a"))
    assert(tree.root.children("a").item.equals("a"))
    assert(tree.root.children("a").count == 2)
    assert(tree.root.children.contains("b"))
    assert(tree.root.children("b").item.equals("b"))
    assert(tree.root.children("b").count == 1)
    var child = tree.root.children("a")
    assert(child.children.size == 1)
    assert(child.children.contains("b"))
    assert(child.children("b").item.equals("b"))
    assert(child.children("b").count == 2)
    child = child.children("b")
    assert(child.children.size == 2)
    assert(child.children.contains("c"))
    assert(child.children.contains("y"))
    assert(child.children("c").item.equals("c"))
    assert(child.children("y").item.equals("y"))
    assert(child.children("c").count == 1)
    assert(child.children("y").count == 1)
  }

  test("merge tree") {
    val tree1 = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    val tree2 = new FPTree[String]
      .add(Seq("a", "b"))
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "c", "d"))
      .add(Seq("a", "x"))
      .add(Seq("a", "x", "y"))
      .add(Seq("c", "n"))
      .add(Seq("c", "m"))

    val tree3 = tree1.merge(tree2)

    assert(tree3.root.children.size == 3)
    assert(tree3.root.children("a").count == 7)
    assert(tree3.root.children("b").count == 1)
    assert(tree3.root.children("c").count == 2)
    val child1 = tree3.root.children("a")
    assert(child1.children.size == 2)
    assert(child1.children("b").count == 5)
    assert(child1.children("x").count == 2)
    val child2 = child1.children("b")
    assert(child2.children.size == 2)
    assert(child2.children("y").count == 1)
    assert(child2.children("c").count == 3)
    val child3 = child2.children("c")
    assert(child3.children.size == 1)
    assert(child3.children("d").count == 1)
    val child4 = child1.children("x")
    assert(child4.children.size == 1)
    assert(child4.children("y").count == 1)
    val child5 = tree3.root.children("c")
    assert(child5.children.size == 2)
    assert(child5.children("n").count == 1)
    assert(child5.children("m").count == 1)
  }

  test("extract freq itemsets") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("a", "b"))
      .add(Seq("a"))
      .add(Seq("b"))
      .add(Seq("b", "n"))

    val freqItemsets = tree.extract(3L).map { case (items, count) =>
      (items.toSet, count)
    }.toSet
    val expected = Set(
      (Set("a"), 4L),
      (Set("b"), 5L),
      (Set("a", "b"), 3L))
    assert(freqItemsets === expected)
  }
} 
Example 56
Source File: SNV.scala    From seqspark   with Apache License 2.0 5 votes vote down vote up
package org.dizhang.seqspark.assoc

import breeze.stats.distributions.Gaussian
import org.dizhang.seqspark.stat.HypoTest.{NullModel => NM}
import org.dizhang.seqspark.stat.{Resampling, ScoreTest, WaldTest}
import org.dizhang.seqspark.util.General._

import scala.language.existentials


trait SNV extends AssocMethod {
  def nullModel: NM
  def x: Encode.Common
  def result: AssocMethod.Result
}

object SNV {
  def apply(nullModel: NM,
            x: Encode.Coding): SNV with AssocMethod.AnalyticTest = {
    nullModel match {
      case nm: NM.Fitted =>
        AnalyticScoreTest(nm, x.asInstanceOf[Encode.Common])
      case _ =>
        AnalyticWaldTest(nullModel, x.asInstanceOf[Encode.Common])
    }
  }

  def apply(ref: Double, min: Int, max: Int,
            nullModel: NM.Fitted,
            x: Encode.Coding): ResamplingTest = {
    ResamplingTest(ref, min, max, nullModel, x.asInstanceOf[Encode.Common])
  }

  def getStatistic(nm: NM.Fitted, x: Encode.Coding): Double = {
    val st = ScoreTest(nm, x.asInstanceOf[Encode.Common].coding)
    math.abs(st.score(0)/st.variance(0,0).sqrt)
  }

  @SerialVersionUID(7727280101L)
  final case class AnalyticScoreTest(nullModel: NM.Fitted,
                                     x: Encode.Common)
    extends SNV with AssocMethod.AnalyticTest
  {
    //val scoreTest = ScoreTest(nullModel, x.coding)
    val statistic = getStatistic(nullModel, x)
    val pValue = {
      val dis = new Gaussian(0.0, 1.0)
      Some((1.0 - dis.cdf(statistic)) * 2)
    }

    def result: AssocMethod.BurdenAnalytic = {
      AssocMethod.BurdenAnalytic(x.vars, statistic, pValue, "test=score")
    }
  }

  case class AnalyticWaldTest(nullModel: NM,
                              x: Encode.Common)
    extends SNV with AssocMethod.AnalyticTest
  {
    private val wt = WaldTest(nullModel, x.coding.toDenseVector)
    val statistic = wt.beta(1) / wt.std(1)
    val pValue = Some(wt.pValue(oneSided = false).apply(1))
    def result = {
      AssocMethod.BurdenAnalytic(x.vars, statistic, pValue, s"test=wald;beta=${wt.beta(1)};betaStd=${wt.std(1)}")
    }
  }

  @SerialVersionUID(7727280201L)
  final case class ResamplingTest(refStatistic: Double,
                                  min: Int,
                                  max: Int,
                                  nullModel: NM.Fitted,
                                  x: Encode.Common)
    extends SNV with AssocMethod.ResamplingTest
  {
    def pCount = Resampling.Simple(refStatistic, min, max, nullModel, x, getStatistic).pCount
    def result: AssocMethod.BurdenResampling = {
      AssocMethod.BurdenResampling(x.vars, refStatistic, pCount)
    }
  }
} 
Example 57
Source File: ShuffleMapTask.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.nio.ByteBuffer

import scala.language.existentials

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleWriter


  def this(partitionId: Int) {
    this(0, null, new Partition { override def index = 0 }, null)
  }

  @transient private val preferredLocs: Seq[TaskLocation] = {
    if (locs == null) Nil else locs.toSet.toSeq
  }

  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      return writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

  override def preferredLocations: Seq[TaskLocation] = preferredLocs

  override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
} 
Example 58
Source File: NettyBlockRpcServer.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network.netty

import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import scala.language.existentials
import scala.reflect.ClassTag

import org.apache.spark.internal.Logging
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, StorageLevel}


class NettyBlockRpcServer(
    appId: String,
    serializer: Serializer,
    blockManager: BlockDataManager)
  extends RpcHandler with Logging {

  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      rpcMessage: ByteBuffer,
      responseContext: RpcResponseCallback): Unit = {
    val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
    logTrace(s"Received request: $message")

    message match {
      case openBlocks: OpenBlocks =>
        val blocks: Seq[ManagedBuffer] =
          openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
        logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)

      case uploadBlock: UploadBlock =>
        // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
        val (level: StorageLevel, classTag: ClassTag[_]) = {
          serializer
            .newInstance()
            .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
            .asInstanceOf[(StorageLevel, ClassTag[_])]
        }
        val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
        val blockId = BlockId(uploadBlock.blockId)
        blockManager.putBlockData(blockId, data, level, classTag)
        responseContext.onSuccess(ByteBuffer.allocate(0))
    }
  }

  override def getStreamManager(): StreamManager = streamManager
} 
Example 59
Source File: FPTreeSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.fpm

import scala.language.existentials

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("add transaction") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    assert(tree.root.children.size == 2)
    assert(tree.root.children.contains("a"))
    assert(tree.root.children("a").item.equals("a"))
    assert(tree.root.children("a").count == 2)
    assert(tree.root.children.contains("b"))
    assert(tree.root.children("b").item.equals("b"))
    assert(tree.root.children("b").count == 1)
    var child = tree.root.children("a")
    assert(child.children.size == 1)
    assert(child.children.contains("b"))
    assert(child.children("b").item.equals("b"))
    assert(child.children("b").count == 2)
    child = child.children("b")
    assert(child.children.size == 2)
    assert(child.children.contains("c"))
    assert(child.children.contains("y"))
    assert(child.children("c").item.equals("c"))
    assert(child.children("y").item.equals("y"))
    assert(child.children("c").count == 1)
    assert(child.children("y").count == 1)
  }

  test("merge tree") {
    val tree1 = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("b"))

    val tree2 = new FPTree[String]
      .add(Seq("a", "b"))
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "c", "d"))
      .add(Seq("a", "x"))
      .add(Seq("a", "x", "y"))
      .add(Seq("c", "n"))
      .add(Seq("c", "m"))

    val tree3 = tree1.merge(tree2)

    assert(tree3.root.children.size == 3)
    assert(tree3.root.children("a").count == 7)
    assert(tree3.root.children("b").count == 1)
    assert(tree3.root.children("c").count == 2)
    val child1 = tree3.root.children("a")
    assert(child1.children.size == 2)
    assert(child1.children("b").count == 5)
    assert(child1.children("x").count == 2)
    val child2 = child1.children("b")
    assert(child2.children.size == 2)
    assert(child2.children("y").count == 1)
    assert(child2.children("c").count == 3)
    val child3 = child2.children("c")
    assert(child3.children.size == 1)
    assert(child3.children("d").count == 1)
    val child4 = child1.children("x")
    assert(child4.children.size == 1)
    assert(child4.children("y").count == 1)
    val child5 = tree3.root.children("c")
    assert(child5.children.size == 2)
    assert(child5.children("n").count == 1)
    assert(child5.children("m").count == 1)
  }

  test("extract freq itemsets") {
    val tree = new FPTree[String]
      .add(Seq("a", "b", "c"))
      .add(Seq("a", "b", "y"))
      .add(Seq("a", "b"))
      .add(Seq("a"))
      .add(Seq("b"))
      .add(Seq("b", "n"))

    val freqItemsets = tree.extract(3L).map { case (items, count) =>
      (items.toSet, count)
    }.toSet
    val expected = Set(
      (Set("a"), 4L),
      (Set("b"), 5L),
      (Set("a", "b"), 3L))
    assert(freqItemsets === expected)
  }
} 
Example 60
Source File: RpcMessages.scala    From spark-monitoring   with MIT License 5 votes vote down vote up
package org.apache.spark.metrics

import java.util.concurrent.TimeUnit

import com.codahale.metrics.{Clock, Reservoir}

trait MetricMessage[T] {
  val namespace: String
  val metricName: String
  val value: T
}

private[metrics] case class CounterMessage(
                                          override val namespace: String,
                                          override val metricName: String,
                                          override val value: Long
                                          ) extends MetricMessage[Long]

private[metrics] case class SettableGaugeMessage[T](
                                                                   override val namespace: String,
                                                                   override val metricName: String,
                                                                   override val value: T
                                                                   ) extends MetricMessage[T]

import scala.language.existentials

private[metrics] case class HistogramMessage(
                                            override val namespace: String,
                                            override val metricName: String,
                                            override val value: Long,
                                            reservoirClass: Class[_ <: Reservoir]
                                            ) extends MetricMessage[Long]

private[metrics] case class MeterMessage(
                                        override val namespace: String,
                                        override val metricName: String,
                                        override val value: Long,
                                        clockClass: Class[_ <: Clock]
                                        ) extends MetricMessage[Long]

private[metrics] case class TimerMessage(
                                        override val namespace: String,
                                        override val metricName: String,
                                        override val value: Long,
                                        timeUnit: TimeUnit,
                                        reservoirClass: Class[_ <: Reservoir],
                                        clockClass: Class[_ <: Clock]
                                        ) extends MetricMessage[Long] 
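
The Class[_ <: Reservoir] and Class[_ <: Clock] fields are what require scala.language.existentials here. A hedged construction sketch, assuming calling code that lives in the same org.apache.spark.metrics package (the message classes are private[metrics]); the namespace and metric names are arbitrary:

import java.util.concurrent.TimeUnit
import com.codahale.metrics.{Clock, ExponentiallyDecayingReservoir}

val histogram = HistogramMessage(
  namespace = "myapp",
  metricName = "request.bytes",
  value = 2048L,
  reservoirClass = classOf[ExponentiallyDecayingReservoir])  // Class[_ <: Reservoir]

val timer = TimerMessage(
  namespace = "myapp",
  metricName = "request.latency",
  value = 15L,
  timeUnit = TimeUnit.MILLISECONDS,
  reservoirClass = classOf[ExponentiallyDecayingReservoir],
  clockClass = classOf[Clock.UserTimeClock])                 // Class[_ <: Clock]
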
Example 61
Source File: ScheduledTaskManager.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.utils

import scala.language.existentials
import java.util.concurrent._
import java.util.UUID
import com.google.common.util.concurrent.ThreadFactoryBuilder
import ScheduledTaskManager._
import scala.util.Try


  def stop() = {
    _taskMap.clear()
    _scheduler.shutdown()
  }
}

object ScheduledTaskManager {
  val DefaultMaxThreads = 4
  val DefaultExecutionDelay = 10 // 10 milliseconds
  val DefaultTimeInterval = 100 // 100 milliseconds
} 
Example 62
Source File: Resampling.scala    From seqspark   with Apache License 2.0 5 votes vote down vote up
package org.dizhang.seqspark.stat

import breeze.linalg.{DenseVector, shuffle}
import breeze.stats.distributions.Bernoulli
import org.dizhang.seqspark.assoc.Encode
import org.dizhang.seqspark.ds.SemiGroup.PairInt
import org.dizhang.seqspark.stat.HypoTest.NullModel

import scala.language.existentials


  def makeNewNullModel: NullModel.Fitted = {
    val newY = makeNewY()
    val cols = nullModel.xs.cols
    NullModel(
      newY,
      nullModel.xs(::, 1 until cols),
      fit = true,
      binary = nullModel.binary
    ).asInstanceOf[NullModel.Fitted]
  }
} 
Example 63
Source File: VT.scala    From seqspark   with Apache License 2.0 5 votes vote down vote up
package org.dizhang.seqspark.assoc

import breeze.linalg._
import org.dizhang.seqspark.stat.HypoTest.{NullModel => NM}
import org.dizhang.seqspark.stat.{Resampling, ScoreTest}
import org.dizhang.seqspark.util.General.RichDouble
import org.slf4j.LoggerFactory

import scala.language.existentials


@SerialVersionUID(7727880001L)
trait VT extends AssocMethod {
  def nullModel: NM
  def x: Encode.VT
  def result: AssocMethod.Result
}

object VT {

  val logger = LoggerFactory.getLogger(getClass)

  def apply(nullModel: NM,
            x: Encode.Coding): VT with AssocMethod.AnalyticTest = {
    val nmf = nullModel match {
      case NM.Simple(y, b) => NM.Fit(y, b)
      case NM.Mutiple(y, c, b) => NM.Fit(y, c, b)
      case nm: NM.Fitted => nm
    }
    AnalyticScoreTest(nmf, x.asInstanceOf[Encode.VT])
  }

  def apply(ref: Double, min: Int, max: Int,
            nullModel: NM.Fitted,
            x: Encode.Coding): ResamplingTest = {
    ResamplingTest(ref, min, max, nullModel, x.asInstanceOf[Encode.VT])
  }

  def getStatistic(nm: NM.Fitted, x: Encode.Coding): Double = {
    //println(s"scores: ${st.score.toArray.mkString(",")}")
    //println(s"variances: ${diag(st.variance).toArray.mkString(",")}")
    val m = x.asInstanceOf[Encode.VT].coding
    val ts = m.map{sv =>
      val st = ScoreTest(nm, sv)
      st.score(0)/st.variance(0, 0).sqrt
    }
    //val ts = st.score :/ diag(st.variance).map(x => x.sqrt)
    max(ts)
  }

  @SerialVersionUID(7727880101L)
  final case class AnalyticScoreTest(nullModel: NM.Fitted,
                                     x: Encode.VT)
    extends VT with AssocMethod.AnalyticTest
  {

    val statistic = getStatistic(nullModel, x)
    val pValue = None
    def result: AssocMethod.VTAnalytic = {
      val info = s"MAFs=${x.coding.length}"
      AssocMethod.VTAnalytic(x.vars, x.size, statistic, pValue, info)
    }
  }

  @SerialVersionUID(7727880201L)
  final case class ResamplingTest(refStatistic: Double,
                                  min: Int,
                                  max: Int,
                                  nullModel: NM.Fitted,
                                  x: Encode.VT)
    extends VT with AssocMethod.ResamplingTest
  {
    def pCount = {
      Resampling.Simple(refStatistic, min, max, nullModel, x, getStatistic).pCount
    }
    def result: AssocMethod.VTResampling =
      AssocMethod.VTResampling(x.vars, x.size, refStatistic, pCount)
  }

} 
Example 64
Source File: Burden.scala    From seqspark   with Apache License 2.0 5 votes vote down vote up
package org.dizhang.seqspark.assoc

import breeze.linalg.DenseVector
import breeze.stats.distributions.{Gaussian, StudentsT}
import org.dizhang.seqspark.stat.HypoTest.{NullModel => NM}
import org.dizhang.seqspark.stat.{Resampling, ScoreTest, WaldTest}
import org.dizhang.seqspark.util.General._

import scala.language.existentials


@SerialVersionUID(7727280001L)
trait Burden extends AssocMethod {
  def nullModel: NM
  def x: Encode.Fixed
  def result: AssocMethod.Result
}

object Burden {

  def apply(nullModel: NM,
            x: Encode.Coding): Burden with AssocMethod.AnalyticTest = {
    nullModel match {
      case nm: NM.Fitted =>
        AnalyticScoreTest(nm, x.asInstanceOf[Encode.Fixed])
      case _ =>
        AnalyticWaldTest(nullModel, x.asInstanceOf[Encode.Fixed])
    }
  }

  def apply(ref: Double, min: Int, max: Int,
            nullModel: NM.Fitted,
            x: Encode.Coding): ResamplingTest = {
    ResamplingTest(ref, min, max, nullModel, x.asInstanceOf[Encode.Fixed])
  }

  def getStatistic(nm: NM.Fitted, x: Encode.Coding): Double = {
    val st = ScoreTest(nm, x.asInstanceOf[Encode.Fixed].coding)
    st.score(0)/st.variance(0,0).sqrt
  }

  def getStatistic(nm: NM, x: DenseVector[Double]): Double = {
    val wt = WaldTest(nm, x)
    (wt.beta /:/ wt.std).apply(1)
  }

  @SerialVersionUID(7727280101L)
  final case class AnalyticScoreTest(nullModel: NM.Fitted,
                                     x: Encode.Fixed)
    extends Burden with AssocMethod.AnalyticTest
  {
    def geno = x.coding
    //val scoreTest = ScoreTest(nullModel, geno)
    val statistic = getStatistic(nullModel, x)
    val pValue = {
      val dis = new Gaussian(0.0, 1.0)
      Some(1.0 - dis.cdf(statistic))
    }

    def result: AssocMethod.BurdenAnalytic = {
      AssocMethod.BurdenAnalytic(x.vars, statistic, pValue, "test=score")
    }

  }
  case class AnalyticWaldTest(nullModel: NM,
                              x: Encode.Fixed) extends Burden with AssocMethod.AnalyticTest {
    def geno = x.coding
    private val wt = WaldTest(nullModel, x.coding)
    val statistic = getStatistic(nullModel, geno)
    val pValue = {
      val dis = new StudentsT(nullModel.dof - 1)
      Some(1.0 - dis.cdf(statistic))
    }
    def result = {
      AssocMethod.BurdenAnalytic(x.vars, statistic, pValue, s"test=wald;beta=${wt.beta(1)};betaStd=${wt.std(1)}")
    }
  }

  @SerialVersionUID(7727280201L)
  final case class ResamplingTest(refStatistic: Double,
                                  min: Int,
                                  max: Int,
                                  nullModel: NM.Fitted,
                                  x: Encode.Fixed)
    extends Burden with AssocMethod.ResamplingTest
  {
    def pCount = Resampling.Simple(refStatistic, min, max, nullModel, x, getStatistic).pCount
    def result: AssocMethod.BurdenResampling = {
      AssocMethod.BurdenResampling(x.vars, refStatistic, pCount)
    }
  }
} 
Example 65
Source File: BytecodeUtils.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.graphx.util

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.HashSet
import scala.language.existentials

import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor}
import org.apache.xbean.asm5.Opcodes._

import org.apache.spark.util.Utils


  private class MethodInvocationFinder(className: String, methodName: String)
    extends ClassVisitor(ASM5) {

    val methodsInvoked = new HashSet[(Class[_], String)]

    override def visitMethod(access: Int, name: String, desc: String,
                             sig: String, exceptions: Array[String]): MethodVisitor = {
      if (name == methodName) {
        new MethodVisitor(ASM5) {
          override def visitMethodInsn(
              op: Int, owner: String, name: String, desc: String, itf: Boolean) {
            if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) {
              if (!skipClass(owner)) {
                methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name))
              }
            }
          }
        }
      } else {
        null
      }
    }
  }
} 
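
The visitor above is normally driven by an ASM ClassReader over raw class-file bytes. A minimal sketch of that driving code; classBytes is a hypothetical Array[Byte] containing a .class file loaded elsewhere, and the owner name "com/example/Foo" is likewise illustrative:

def findInvocations(classBytes: Array[Byte], methodName: String): Set[(Class[_], String)] = {
  val reader = new ClassReader(classBytes)
  val finder = new MethodInvocationFinder("com/example/Foo", methodName)
  reader.accept(finder, 0)          // walk the class, collecting invoked (class, method) pairs
  finder.methodsInvoked.toSet
}
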
Example 66
Source File: MyNettyBlockRpcServer.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network.netty

import java.nio.ByteBuffer

import scala.language.existentials

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
import org.apache.spark.network.shuffle.protocol._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.remote.{HadoopFileSegmentManagedBuffer, MessageForHadoopManagedBuffers, RemoteShuffleManager}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.storage.{BlockId, ShuffleBlockId}


class MyNettyBlockRpcServer(
    appId: String,
    serializer: Serializer,
    blockManager: BlockDataManager)
  extends RpcHandler with Logging {

  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      rpcMessage: ByteBuffer,
      responseContext: RpcResponseCallback): Unit = {
    val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
    logTrace(s"Received request: $message")

    message match {
      case openBlocks: OpenBlocks =>
        val blocksNum = openBlocks.blockIds.length
        val isShuffleRequest = (blocksNum > 0) &&
          BlockId.apply(openBlocks.blockIds(0)).isInstanceOf[ShuffleBlockId] &&
          (SparkEnv.get.conf.get("spark.shuffle.manager", classOf[SortShuffleManager].getName)
            == classOf[RemoteShuffleManager].getName)
        if (isShuffleRequest) {
          val blockIdAndManagedBufferPair =
            openBlocks.blockIds.map(block => (block, blockManager.getHostLocalShuffleData(
              BlockId.apply(block), Array.empty).asInstanceOf[HadoopFileSegmentManagedBuffer]))
          responseContext.onSuccess(new MessageForHadoopManagedBuffers(
            blockIdAndManagedBufferPair).toByteBuffer.nioBuffer())
        } else {
          // This customized Netty RPC server is only served for RemoteShuffle requests,
          // Other RPC messages or data chunks transferring should go through
          // NettyBlockTransferService' NettyBlockRpcServer
          throw new UnsupportedOperationException("MyNettyBlockRpcServer only serves remote" +
            " shuffle requests for OpenBlocks")
        }

      case uploadBlock: UploadBlock =>
        throw new UnsupportedOperationException("MyNettyBlockRpcServer doesn't serve UploadBlock")
    }
  }

  override def receiveStream(
      client: TransportClient,
      messageHeader: ByteBuffer,
      responseContext: RpcResponseCallback): StreamCallbackWithID = {
    throw new UnsupportedOperationException("MyNettyBlockRpcServer doesn't support receiving" +
      " stream")
  }

  override def getStreamManager(): StreamManager = streamManager
} 
Example 67
Source File: RowToVectorBuilder.scala    From filo   with Apache License 2.0 5 votes vote down vote up
package org.velvia.filo

import java.nio.ByteBuffer
import scala.language.existentials
import scala.language.postfixOps
import scalaxy.loops._

import BuilderEncoder.{EncodingHint, AutoDetect}

case class VectorInfo(name: String, dataType: Class[_])

// To help matching against the ClassTag in the VectorBuilder
private object Classes {
  val Boolean = classOf[Boolean]
  val Byte = java.lang.Byte.TYPE
  val Short = java.lang.Short.TYPE
  val Int = java.lang.Integer.TYPE
  val Long = java.lang.Long.TYPE
  val Float = java.lang.Float.TYPE
  val Double = java.lang.Double.TYPE
  val String = classOf[String]
  val DateTime = classOf[org.joda.time.DateTime]
  val SqlTimestamp = classOf[java.sql.Timestamp]
  val UTF8 = classOf[ZeroCopyUTF8String]
}

object RowToVectorBuilder {
  
  def convertToBytes(hint: EncodingHint = AutoDetect): Map[String, ByteBuffer] = {
    val chunks = builders.map(_.toFiloBuffer(hint))
    schema.zip(chunks).map { case (VectorInfo(colName, _), bytes) => (colName, bytes) }.toMap
  }

  private def unsupportedInput(typ: Any) =
    throw new RuntimeException("Unsupported input type " + typ)
} 
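
VectorInfo stores a Class[_], which is why this file imports scala.language.existentials: a single schema sequence can mix differently-typed columns. A small sketch:

import org.velvia.filo.VectorInfo

val schema: Seq[VectorInfo] = Seq(
  VectorInfo("id",        classOf[Int]),
  VectorInfo("name",      classOf[String]),
  VectorInfo("createdAt", classOf[java.sql.Timestamp]))

schema.foreach(col => println(s"${col.name}: ${col.dataType.getSimpleName}"))
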
Example 68
Source File: FlinkScalarFunctionGenerator.scala    From milan   with Apache License 2.0 5 votes vote down vote up
package com.amazon.milan.compiler.flink.internal

import com.amazon.milan.compiler.scala.{CodeBlock, DefaultTypeEmitter, ScalarFunctionGenerator, TypeEmitter}
import com.amazon.milan.compiler.flink.generator.FlinkGeneratorException
import com.amazon.milan.compiler.flink.typeutil._
import com.amazon.milan.program.ValueDef
import com.amazon.milan.types._
import com.amazon.milan.typeutil.{TypeDescriptor, types}

import scala.language.existentials


object FlinkScalarFunctionGenerator {
  val default = new FlinkScalarFunctionGenerator(new DefaultTypeEmitter)
}


case class FunctionParts(arguments: CodeBlock, returnType: CodeBlock, body: CodeBlock)


class FlinkScalarFunctionGenerator(typeEmitter: TypeEmitter) extends ScalarFunctionGenerator(typeEmitter, ContextualTreeTransformer) {

  
  private class ArrayFieldConversionContext(tupleType: TypeDescriptor[_]) extends ConversionContext {
    override def generateSelectTermAndContext(name: String): (String, ConversionContext) = {
      if (name == RecordIdFieldName) {
        // RecordId is a special field for tuple streams, because it's a property of the ArrayRecord class rather than
        // being present in the fields array itself.
        (s".$name", createContextForType(types.String))
      }
      else {
        val fieldIndex = this.tupleType.fields.takeWhile(_.name != name).length

        if (fieldIndex >= this.tupleType.fields.length) {
          throw new FlinkGeneratorException(s"Field '$name' not found.")
        }

        val fieldType = this.tupleType.fields(fieldIndex).fieldType
        (s"($fieldIndex).asInstanceOf[${typeEmitter.getTypeFullName(fieldType)}]", createContextForType(fieldType))
      }
    }
  }

  override protected def createContextForArgument(valueDef: ValueDef): ConversionContext = {
    // If the record type is a tuple with named fields then this is a tuple stream whose records are stored as
    // ArrayRecord objects.
    if (valueDef.tpe.isTupleRecord) {
      new ArrayArgumentConversionContext(valueDef.name, valueDef.tpe)
    }
    else {
      super.createContextForArgument(valueDef)
    }
  }

  override protected def createContextForType(contextType: TypeDescriptor[_]): ConversionContext = {
    // If the context type is a tuple with named fields then term names must be mapped to indices in the ArrayRecord
    // objects.
    if (contextType.isTupleRecord) {
      new ArrayFieldConversionContext(contextType)
    }
    else {
      super.createContextForType(contextType)
    }
  }
} 
Example 69
Source File: TestWindow.scala    From milan   with Apache License 2.0 5 votes vote down vote up
package com.amazon.milan.lang

import java.time.Duration

import com.amazon.milan.lang.aggregation._
import com.amazon.milan.program
import com.amazon.milan.program.{GroupBy, _}
import com.amazon.milan.test.{DateIntRecord, DateKeyValueRecord}
import com.amazon.milan.typeutil.{FieldDescriptor, types}
import org.junit.Assert._
import org.junit.Test

import scala.language.existentials


@Test
class TestWindow {
  @Test
  def test_TumblingWindow_ReturnsStreamWithCorrectInputNodeAndWindowProperties(): Unit = {
    val stream = Stream.of[DateIntRecord]
    val windowed = stream.tumblingWindow(r => r.dateTime, Duration.ofHours(1), Duration.ofMinutes(30))

    val TumblingWindow(_, dateExtractorFunc, period, offset) = windowed.expr

    // If this extraction doesn't throw an exception then the formula is correct.
    val FunctionDef(List(ValueDef("r", _)), SelectField(SelectTerm("r"), "dateTime")) = dateExtractorFunc

    assertEquals(Duration.ofHours(1), period.asJava)
    assertEquals(Duration.ofMinutes(30), offset.asJava)
  }

  @Test
  def test_TumblingWindow_ThenSelectToTuple_ReturnsStreamWithCorrectFieldComputationExpression(): Unit = {
    val stream = Stream.of[DateIntRecord]
    val grouped = stream.tumblingWindow(r => r.dateTime, Duration.ofHours(1), Duration.ofMinutes(30))
    val selected = grouped.select((key, r) => fields(field("max", max(r.i))))

    val Aggregate(source, FunctionDef(_, NamedFields(fieldList))) = selected.expr

    assertEquals(1, selected.recordType.fields.length)
    assertEquals(FieldDescriptor("max", types.Int), selected.recordType.fields.head)

    assertEquals(1, fieldList.length)
    assertEquals("max", fieldList.head.fieldName)

    // If this extraction statement doesn't crash then we're good.
    val Max(SelectField(SelectTerm("r"), "i")) = fieldList.head.expr
  }

  @Test
  def test_SlidingWindow_ReturnsStreamWithCorrectInputNodeAndWindowProperties(): Unit = {
    val stream = Stream.of[DateIntRecord]
    val windowed = stream.slidingWindow(r => r.dateTime, Duration.ofHours(1), Duration.ofMinutes(10), Duration.ofMinutes(30))

    val SlidingWindow(_, dateExtractorFunc, size, slide, offset) = windowed.expr

    val FunctionDef(List(ValueDef("r", _)), SelectField(SelectTerm("r"), "dateTime")) = dateExtractorFunc

    assertEquals(Duration.ofHours(1), size.asJava)
    assertEquals(Duration.ofMinutes(10), slide.asJava)
    assertEquals(Duration.ofMinutes(30), offset.asJava)
  }

  @Test
  def test_GroupBy_ThenTumblingWindow_ThenSelect_ReturnsStreamWithCorrectInputNodeAndWindowProperties(): Unit = {
    val input = Stream.of[DateKeyValueRecord].withId("input")
    val output = input.groupBy(r => r.key)
      .tumblingWindow(r => r.dateTime, Duration.ofMinutes(5), Duration.ZERO)
      .select((windowStart, r) => any(r))

    val Aggregate(windowExpr, FunctionDef(List(ValueDef("windowStart", _), ValueDef("r", _)), First(SelectTerm("r")))) = output.expr
    val TumblingWindow(groupExpr, FunctionDef(List(ValueDef("r", _)), SelectField(SelectTerm("r"), "dateTime")), program.Duration(300000), program.Duration(0)) = windowExpr
    val GroupBy(ExternalStream("input", "input", _), FunctionDef(List(ValueDef("r", _)), SelectField(SelectTerm("r"), "key"))) = groupExpr
  }
} 
Example 70
Source File: Surface.scala    From airframe   with Apache License 2.0 5 votes vote down vote up
package wvlet.airframe.surface

import scala.language.existentials
import scala.language.experimental.macros

object Surface {
  def of[A]: Surface = macro SurfaceMacros.surfaceOf[A]
  def methodsOf[A]: Seq[MethodSurface] = macro SurfaceMacros.methodSurfaceOf[A]
}

trait Surface extends Serializable {
  def rawType: Class[_]
  def typeArgs: Seq[Surface]
  def params: Seq[Parameter]
  def name: String
  def fullName: String

  def dealias: Surface = this

  def isOption: Boolean
  def isAlias: Boolean
  def isPrimitive: Boolean
  def isSeq: Boolean = classOf[Seq[_]].isAssignableFrom(rawType)

  def objectFactory: Option[ObjectFactory] = None
}

sealed trait ParameterBase extends Serializable {
  def name: String
  def surface: Surface

  def call(obj: Any, x: Any*): Any
}

trait Parameter extends ParameterBase {
  def index: Int
  def name: String

  
  def getMethodArgDefaultValue(methodOwner: Any): Option[Any] = getDefaultValue
}

trait MethodSurface extends ParameterBase {
  def mod: Int
  def owner: Surface
  def name: String
  def args: Seq[MethodParameter]
  def surface: Surface = returnType
  def returnType: Surface

  def isPublic: Boolean    = (mod & MethodModifier.PUBLIC) != 0
  def isPrivate: Boolean   = (mod & MethodModifier.PRIVATE) != 0
  def isProtected: Boolean = (mod & MethodModifier.PROTECTED) != 0
  def isStatic: Boolean    = (mod & MethodModifier.STATIC) != 0
  def isFinal: Boolean     = (mod & MethodModifier.FINAL) != 0
  def isAbstract: Boolean  = (mod & MethodModifier.ABSTRACT) != 0
} 
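
A brief usage sketch for the Surface API above; the printed values are approximate, not verified output:

import wvlet.airframe.surface.Surface

val s = Surface.of[Map[String, Int]]
println(s.name)                       // roughly "Map[String,Int]"
val raw: Class[_] = s.rawType         // the existential Class[_] exposed by the trait
s.typeArgs.foreach(arg => println(arg.fullName))
println(s.isPrimitive)                // false
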
Example 71
Source File: AirframeException.scala    From airframe   with Apache License 2.0 5 votes vote down vote up
package wvlet.airframe

import wvlet.airframe.surface.Surface
import scala.language.existentials

trait AirframeException extends Exception { self =>

  
  def getCode: String           = this.getClass.getSimpleName
  override def toString: String = getMessage
}

object AirframeException {
  case class MISSING_SESSION(cl: Class[_]) extends AirframeException {
    override def getMessage: String =
      s"[$getCode] Session is not found inside ${cl}. You may need to define ${cl} as a trait or implement DISupport to inject the current Session."
  }
  case class CYCLIC_DEPENDENCY(deps: List[Surface], sourceCode: SourceCode) extends AirframeException {
    override def getMessage: String = s"[$getCode] ${deps.reverse.mkString(" -> ")} at ${sourceCode}"
  }
  case class MISSING_DEPENDENCY(stack: List[Surface], sourceCode: SourceCode) extends AirframeException {
    override def getMessage: String =
      s"[$getCode] Binding for ${stack.head} at ${sourceCode} is not found: ${stack.mkString(" <- ")}"
  }

  case class SHUTDOWN_FAILURE(cause: Throwable) extends AirframeException {
    override def getMessage: String = {
      s"[${getCode}] Failure at session shutdown: ${cause.getMessage}"
    }
  }
  case class MULTIPLE_SHUTDOWN_FAILURES(causes: List[Throwable]) extends AirframeException {
    override def getMessage: String = {
      s"[${getCode}] Multiple failures occurred during session shutdown:\n${causes.map(x => s"  - ${x.getMessage}").mkString("\n")}"
    }
  }
} 
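
The MISSING_SESSION case carries an existential Class[_]. A tiny construction sketch using only what this listing defines:

import wvlet.airframe.AirframeException.MISSING_SESSION

val err = MISSING_SESSION(classOf[String])
println(err.getCode)      // "MISSING_SESSION"
println(err.getMessage)   // reports that no Session was found inside class java.lang.String
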
Example 72
Source File: RunServer.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package ml.combust.mleap.grpc.server

import java.util.concurrent.{Executors, TimeUnit}

import akka.Done
import akka.actor.{ActorSystem, CoordinatedShutdown}
import akka.stream.{ActorMaterializer, Materializer}
import com.typesafe.config.Config
import com.typesafe.scalalogging.Logger
import io.grpc.ServerBuilder
import ml.combust.mleap.executor.MleapExecutor
import ml.combust.mleap.pb.MleapGrpc

import scala.concurrent.{ExecutionContext, Future}
import scala.language.existentials
import scala.util.{Failure, Success, Try}

class RunServer(config: Config)
               (implicit system: ActorSystem) {
  private val logger = Logger(classOf[RunServer])

  private var coordinator: Option[CoordinatedShutdown] = None

  def run(): Unit = {
    Try {
      logger.info("Starting MLeap gRPC Server")

      val coordinator = CoordinatedShutdown(system)
      this.coordinator = Some(coordinator)

      implicit val materializer: Materializer = ActorMaterializer()

      val grpcServerConfig = new GrpcServerConfig(config.getConfig("default"))
      val mleapExecutor = MleapExecutor(system)
      val port: Int = config.getInt("port")
      val threads: Option[Int] = if (config.hasPath("threads")) Some(config.getInt("threads")) else None
      val threadCount = threads.getOrElse {
        Math.min(Math.max(Runtime.getRuntime.availableProcessors() * 4, 32), 64)
      }

      logger.info(s"Creating thread pool for server with size $threadCount")
      val grpcThreadPool = Executors.newFixedThreadPool(threadCount)
      implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(grpcThreadPool)

      coordinator.addTask(CoordinatedShutdown.PhaseServiceRequestsDone, "threadPoolShutdownNow") {
        () =>
          Future {
            logger.info("Shutting down gRPC thread pool")
            grpcThreadPool.shutdown()
            grpcThreadPool.awaitTermination(5, TimeUnit.SECONDS)

            Done
          }
      }

      logger.info(s"Creating executor service")
      val grpcService: GrpcServer = new GrpcServer(mleapExecutor, grpcServerConfig)
      val builder = ServerBuilder.forPort(port)
      builder.intercept(new ErrorInterceptor)
      builder.addService(MleapGrpc.bindService(grpcService, ec))
      val grpcServer = builder.build()

      logger.info(s"Starting server on port $port")
      grpcServer.start()

      coordinator.addTask(CoordinatedShutdown.PhaseServiceUnbind, "grpcServiceShutdown") {
        () =>
          Future {
            logger.info("Shutting down gRPC")
            grpcServer.shutdown()
            grpcServer.awaitTermination(10, TimeUnit.SECONDS)
            Done
          }(ExecutionContext.global)
      }

      coordinator.addTask(CoordinatedShutdown.PhaseServiceStop, "grpcServiceShutdownNow") {
        () =>
          Future {
            if (!grpcServer.isShutdown) {
              logger.info("Shutting down gRPC NOW!")

              grpcServer.shutdownNow()
              grpcServer.awaitTermination(5, TimeUnit.SECONDS)
            }

            Done
          }(ExecutionContext.global)
      }
    } match {
      case Success(_) =>
      case Failure(err) =>
        logger.error("Error encountered starting server", err)
        for (c <- this.coordinator) {
          c.run(CoordinatedShutdown.UnknownReason)
        }
        throw err
    }
  }
} 
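
Rather than installing JVM shutdown hooks, RunServer registers its cleanup logic as CoordinatedShutdown tasks bound to specific Akka shutdown phases. A minimal, self-contained sketch of that registration pattern follows; it uses only Akka, and the actor-system name, task name, and println body are illustrative.

import akka.Done
import akka.actor.{ActorSystem, CoordinatedShutdown}

import scala.concurrent.Future

object ShutdownTaskSketch {
  def main(args: Array[String]): Unit = {
    implicit val system: ActorSystem = ActorSystem("shutdown-sketch")
    import system.dispatcher

    // Same phase RunServer uses to stop accepting new gRPC calls before the hard stop.
    CoordinatedShutdown(system).addTask(CoordinatedShutdown.PhaseServiceUnbind, "demoUnbind") { () =>
      Future {
        println("running demo unbind task")
        Done
      }
    }

    // Running CoordinatedShutdown executes the registered tasks phase by phase and terminates the system.
    CoordinatedShutdown(system).run(CoordinatedShutdown.UnknownReason)
  }
}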
Example 73
Source File: Responses.scala    From finagle-postgres   with Apache License 2.0
package com.twitter.finagle.postgres

import java.nio.charset.Charset

import com.twitter.finagle.postgres.messages.{DataRow, Field}
import com.twitter.finagle.postgres.values.ValueDecoder
import com.twitter.util.Try
import Try._
import com.twitter.concurrent.AsyncStream
import com.twitter.finagle.postgres.PostgresClient.TypeSpecifier
import com.twitter.finagle.postgres.codec.NullValue
import io.netty.buffer.ByteBuf

import scala.language.existentials

// capture all common format data for a set of rows to reduce repeated references
case class RowFormat(
  indexMap: Map[String, Int],
  formats: Array[Short],
  oids: Array[Int],
  dataTypes: Map[Int, TypeSpecifier],
  receives: PartialFunction[String, ValueDecoder[T] forSome {type T}],
  charset: Charset
) {
  @inline final def recv(index: Int) = dataTypes(oids(index)).receiveFunction
  @inline final def defaultDecoder(index: Int) = receives.applyOrElse(recv(index), (_: String) => ValueDecoder.never)
}

trait Row {
  def getOption[T](name: String)(implicit decoder: ValueDecoder[T]): Option[T]
  def getOption[T](index: Int)(implicit decoder: ValueDecoder[T]): Option[T]
  def get[T](name: String)(implicit decoder: ValueDecoder[T]): T
  def get[T](index: Int)(implicit decoder: ValueDecoder[T]): T
  def getTry[T](name: String)(implicit decoder: ValueDecoder[T]): Try[T]
  def getTry[T](index: Int)(implicit decoder: ValueDecoder[T]): Try[T]
  def getOrElse[T](name: String, default: => T)(implicit decoder: ValueDecoder[T]): T
  def getOrElse[T](index: Int, default: => T)(implicit decoder: ValueDecoder[T]): T
  def getAnyOption(name: String): Option[Any]
  def getAnyOption(index: Int): Option[Any]
}

object Row {
  def apply(values: Array[Option[ByteBuf]], rowFormat: RowFormat): Row = RowImpl(values, rowFormat)
}


object ResultSet {
  def apply(
    fields: Array[Field],
    charset: Charset,
    dataRows: AsyncStream[DataRow],
    types: Map[Int, TypeSpecifier],
    receives: PartialFunction[String, ValueDecoder[T] forSome { type T }]
  ): ResultSet = {
    val (indexMap, formats, oids) = {
      val l = fields.length
      val stringIndex = new Array[(String, Int)](l)
      val formats = new Array[Short](l)
      val oids = new Array[Int](l)
      var i = 0
      while(i < l) {
        val Field(name, format, dataType) = fields(i)
        stringIndex(i) = (name, i)
        formats(i) = format
        oids(i) = dataType
        i += 1
      }
      (stringIndex.toMap, formats, oids)
    }

    val rowFormat = RowFormat(indexMap, formats, oids, types, receives, charset)

    val rows = dataRows.map {
      dataRow => Row(
        values = dataRow.data,
        rowFormat = rowFormat
      )
    }

    ResultSet(rows)
  }
}
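
The `receives` parameter is where `scala.language.existentials` earns its import: each receive-function name maps to a decoder of some element type that the caller deliberately forgets. A minimal, self-contained sketch of that existential-type pattern, with a hypothetical Decoder trait standing in for ValueDecoder:

import scala.language.existentials

object ExistentialDecoderSketch {
  trait Decoder[T] { def decode(raw: String): T }

  val intDecoder: Decoder[Int] = new Decoder[Int] {
    def decode(raw: String): Int = raw.toInt
  }
  val boolDecoder: Decoder[Boolean] = new Decoder[Boolean] {
    def decode(raw: String): Boolean = raw == "t"
  }

  // The value type is existential: each entry may decode to a different, unknown T,
  // mirroring the map of ValueDecoders keyed by Postgres receive-function name above.
  val receives: PartialFunction[String, Decoder[T] forSome { type T }] = {
    case "int4recv" => intDecoder
    case "boolrecv" => boolDecoder
  }

  def main(args: Array[String]): Unit = {
    val d = receives("int4recv")
    println(d.decode("42")) // prints: 42 (statically, the result is only known to be *some* T)
  }
}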