scala.collection.mutable.ArrayBuffer Scala Examples

The following examples show how to use scala.collection.mutable.ArrayBuffer. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: IntegrationTest.scala    From kmq   with Apache License 2.0 6 votes vote down vote up
package com.softwaremill.kmq.redelivery

import java.time.Duration
import java.util.Random

import akka.actor.ActorSystem
import akka.kafka.scaladsl.{Consumer, Producer}
import akka.kafka.{ConsumerSettings, ProducerMessage, ProducerSettings, Subscriptions}
import akka.stream.ActorMaterializer
import akka.testkit.TestKit
import com.softwaremill.kmq._
import com.softwaremill.kmq.redelivery.infrastructure.KafkaSpec
import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord}
import org.apache.kafka.common.serialization.StringDeserializer
import org.scalatest.concurrent.Eventually
import org.scalatest.time.{Seconds, Span}
import org.scalatest.{BeforeAndAfterAll, FlatSpecLike, Matchers}

import scala.collection.mutable.ArrayBuffer

/**
  * End-to-end test of the KMQ marker flow against the Kafka instance provided by `KafkaSpec`.
  *
  * The test builds a consumer stream that writes a start marker for every message, commits the
  * offset, randomly drops some messages (no end marker is written for those) and writes an end
  * marker for the rest. A redelivery tracker is started next to the stream and the test waits
  * until every sent message has been processed at least once.
  */
class IntegrationTest extends TestKit(ActorSystem("test-system")) with FlatSpecLike with KafkaSpec with BeforeAndAfterAll with Eventually with Matchers {

  implicit val materializer = ActorMaterializer()
  import system.dispatcher

  "KMQ" should "resend message if not committed" in {
    val bootstrapServer = s"localhost:${testKafkaConfig.kafkaPort}"
    val kmqConfig = new KmqConfig("queue", "markers", "kmq_client", "kmq_redelivery", Duration.ofSeconds(1).toMillis,
    1000)

    val consumerSettings = ConsumerSettings(system, new StringDeserializer, new StringDeserializer)
      .withBootstrapServers(bootstrapServer)
      .withGroupId(kmqConfig.getMsgConsumerGroupId)
      .withProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")

    val markerProducerSettings = ProducerSettings(system,
      new MarkerKey.MarkerKeySerializer(), new MarkerValue.MarkerValueSerializer())
      .withBootstrapServers(bootstrapServer)
      .withProperty(ProducerConfig.PARTITIONER_CLASS_CONFIG, classOf[ParititionFromMarkerKey].getName)
    val markerProducer = markerProducerSettings.createKafkaProducer()

    val random = new Random()

    // Filled from inside the stream stages and read by the assertions below.
    lazy val processedMessages = ArrayBuffer[String]()
    lazy val receivedMessages = ArrayBuffer[String]()

    val control = Consumer.committableSource(consumerSettings, Subscriptions.topics(kmqConfig.getMsgTopic)) // 1. get messages from topic
      .map { msg =>
      ProducerMessage.Message(
        new ProducerRecord[MarkerKey, MarkerValue](kmqConfig.getMarkerTopic, MarkerKey.fromRecord(msg.record), new StartMarker(kmqConfig.getMsgTimeoutMs)), msg)
    }
      .via(Producer.flow(markerProducerSettings, markerProducer)) // 2. write the "start" marker
      .map(_.message.passThrough)
      .mapAsync(1) { msg =>
        msg.committableOffset.commitScaladsl().map(_ => msg.record) // 3. commit the offset; this should be batched
      }
      .map { msg =>
        receivedMessages += msg.value
        msg
      }
      // 4. simulate processing: a message is dropped whenever nextInt(5) returns 0,
      //    so no end marker is written for it
      .filter(_ => random.nextInt(5) != 0)
      .map { processedMessage =>
        processedMessages += processedMessage.value
        new ProducerRecord[MarkerKey, MarkerValue](kmqConfig.getMarkerTopic, MarkerKey.fromRecord(processedMessage), EndMarker.INSTANCE)
      }
      .to(Producer.plainSink(markerProducerSettings, markerProducer)) // 5. write "end" markers
      .run()

    val redeliveryHook = RedeliveryTracker.start(new KafkaClients(bootstrapServer), kmqConfig)

    // Release the tracker and the stream even when the assertions time out; otherwise a failed
    // run would leave both running until the actor system is torn down.
    try {
      val messages = (0 to 20).map(_.toString)
      messages.foreach(msg => sendToKafka(kmqConfig.getMsgTopic,msg))

      eventually {
        // more deliveries than processed messages means at least one message was dropped
        receivedMessages.size should be > processedMessages.size
        // every message must have been processed at least once
        processedMessages.sortBy(_.toInt).distinct shouldBe messages
      }(PatienceConfig(timeout = Span(15, Seconds)), implicitly)
    } finally {
      try redeliveryHook.close()
      finally control.shutdown()
    }
  }

  override def afterAll(): Unit = {
    super.afterAll()
    TestKit.shutdownActorSystem(system)
  }
}
Example 2
Source File: UndoSnackbarManager.scala    From shadowsocksr-android   with GNU General Public License v3.0 5 votes vote down vote up
package com.github.shadowsocks.widget

import android.support.design.widget.Snackbar
import android.view.View
import com.github.shadowsocks.R

import scala.collection.mutable.ArrayBuffer


/**
  * Collects removed items behind a single "undo" snackbar.
  *
  * Each removed item is parked in a recycle bin together with its index. Clicking the undo action
  * hands the parked items back through `undo`; a swipe, manual dismissal or timeout hands them to
  * `commit`.
  *
  * @param view   view the snackbar is created for
  * @param undo   receives the parked items, most recently removed first, when undo is clicked
  * @param commit receives the parked items in removal order when the snackbar is swiped away,
  *               dismissed manually or times out; may be null, in which case nothing is called
  * @tparam T type of the removed items
  */
class UndoSnackbarManager[T](view: View, undo: Iterator[(Int, T)] => Unit,
                             commit: Iterator[(Int, T)] => Unit = null) {
  private val recycleBin = new ArrayBuffer[(Int, T)]
  private val removedCallback = new Snackbar.Callback {
    override def onDismissed(snackbar: Snackbar, event: Int): Unit = {
      event match {
        case Snackbar.Callback.DISMISS_EVENT_SWIPE | Snackbar.Callback.DISMISS_EVENT_MANUAL |
             Snackbar.Callback.DISMISS_EVENT_TIMEOUT =>
          if (commit != null) commit(recycleBin.iterator)
          recycleBin.clear()
        case _ =>
      }
      // Only forget the snackbar that was actually dismissed. When remove() replaces a visible
      // snackbar, the old one reports its dismissal after `last` already points at the new one;
      // clearing unconditionally would make flush a no-op for the snackbar on screen.
      if (last eq snackbar) last = null
    }
  }
  private var last: Snackbar = _

  /**
    * Parks `item` (removed at `index`) and shows a snackbar reporting how many items are parked.
    */
  def remove(index: Int, item: T): Unit = {
    recycleBin.append((index, item))
    val count = recycleBin.length
    last = Snackbar
      .make(view, view.getResources.getQuantityString(R.plurals.removed, count, count: Integer), Snackbar.LENGTH_LONG)
      .setCallback(removedCallback).setAction(R.string.undo, (_ => {
      undo(recycleBin.reverseIterator)
      recycleBin.clear()
    }): View.OnClickListener)
    last.show
  }

  /** Dismisses the snackbar currently tracked by this manager, if there is one. */
  def flush: Unit = if (last != null) last.dismiss
}
Example 3
Source File: SinkRouteHandler.scala    From ohara   with Apache License 2.0 5 votes vote down vote up
package oharastream.ohara.shabondi.sink

import java.time.{Duration => JDuration}
import java.util.concurrent.TimeUnit

import akka.actor.ActorSystem
import akka.http.scaladsl.model.{ContentTypes, HttpEntity, StatusCodes}
import akka.http.scaladsl.server.{ExceptionHandler, Route}
import com.typesafe.scalalogging.Logger
import oharastream.ohara.common.data.Row
import oharastream.ohara.common.util.Releasable
import oharastream.ohara.shabondi.common.{JsonSupport, RouteHandler, ShabondiUtils}
import org.apache.commons.lang3.StringUtils

import scala.collection.mutable.ArrayBuffer
import scala.compat.java8.DurationConverters._
import scala.concurrent.ExecutionContextExecutor
import scala.concurrent.duration.Duration
import spray.json.DefaultJsonProtocol._
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._

private[shabondi] object SinkRouteHandler {
  /** Creates a [[SinkRouteHandler]] for `config`, passing along the implicit actor system. */
  def apply(config: SinkConfig)(implicit actorSystem: ActorSystem): SinkRouteHandler = {
    val handler = new SinkRouteHandler(config)
    handler
  }
}

/**
  * HTTP route handler that serves the rows queued per data group.
  *
  * `GET groups/<groupId>` drains the queue of the named group (creating the group when absent)
  * and completes with the drained rows; every other request is answered with a status code and a
  * pointer to the API URL.
  */
private[shabondi] class SinkRouteHandler(config: SinkConfig)(implicit actorSystem: ActorSystem) extends RouteHandler {
  implicit private val contextExecutor: ExecutionContextExecutor = actorSystem.dispatcher

  private val log              = Logger(classOf[SinkRouteHandler])
  private[sink] val dataGroups = SinkDataGroups(config)

  /**
    * Schedules a recurring call to `dataGroups.freeIdleGroup(idleTime)`: first run after one
    * second, then with a fixed delay of `interval` between runs.
    */
  def scheduleFreeIdleGroups(interval: JDuration, idleTime: JDuration): Unit = {
    val initialDelay = Duration(1, TimeUnit.SECONDS)
    actorSystem.scheduler.scheduleWithFixedDelay(initialDelay, interval.toScala) { () =>
      log.trace("scheduled free group, total group: {} ", dataGroups.size)
      dataGroups.freeIdleGroup(idleTime)
    }
  }

  // Logs the failure and answers with 500 plus the exception message.
  private val exceptionHandler = ExceptionHandler {
    case cause: Throwable =>
      log.error(cause.getMessage, cause)
      complete(StatusCodes.InternalServerError -> cause.getMessage)
  }

  /** Polls `queue` until it returns null and returns everything polled, in polling order. */
  private def fullyPollQueue(queue: RowQueue): Seq[Row] = {
    val drained = ArrayBuffer.empty[Row]
    Iterator.continually(queue.poll()).takeWhile(_ != null).foreach(drained += _)
    drained.toSeq
  }

  private def apiUrl = ShabondiUtils.apiUrl

  def route(): Route = handleExceptions(exceptionHandler) {
    path("groups" / Segment) { groupId =>
      get {
        if (!StringUtils.isAlphanumeric(groupId)) {
          val rejection =
            HttpEntity(ContentTypes.`text/plain(UTF-8)`, "Illegal group name, only accept alpha and numeric.")
          complete(StatusCodes.NotAcceptable -> rejection)
        } else {
          val rows = fullyPollQueue(dataGroups.createIfAbsent(groupId).queue)
          val result = rows.map(row => JsonSupport.toRowData(row))
          complete(result)
        }
      } ~ {
        // same path, any method other than GET
        complete(StatusCodes.MethodNotAllowed -> s"Unsupported method, please reference: $apiUrl")
      }
    } ~ {
      // any other path
      complete(StatusCodes.NotFound -> s"Please reference: $apiUrl")
    }
  }

  override def close(): Unit = Releasable.close(dataGroups)
}
Example 4
Source File: CSVConverter.scala    From spark-snowflake   with Apache License 2.0 5 votes vote down vote up
package net.snowflake.spark.snowflake

import org.apache.spark.sql.types.StructType
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/**
  * Splits delimited text lines into fields and feeds them to a row converter.
  *
  * Fields are separated by `|`. A field that starts with `"` is read in quoted mode, where the
  * delimiter may appear inside the field and a quote character escapes the character after it.
  */
object CSVConverter {

  private final val delimiter = '|'
  private final val quoteChar = '"'

  /**
    * Lazily converts every line of `partition` into a `T`.
    *
    * @param partition    lines to parse, one record per element
    * @param resultSchema schema handed to `Conversions.createRowConverter`
    * @return an iterator producing one converted value per input line
    */
  private[snowflake] def convert[T: ClassTag](
    partition: Iterator[String],
    resultSchema: StructType
  ): Iterator[T] = {
    val converter = Conversions.createRowConverter[T](resultSchema)
    partition.map(s => {
      val fields = ArrayBuffer.empty[String]
      var buff = new StringBuilder

      // Appends the buffered text as the next field and resets the buffer.
      // An empty buffer is recorded as null, so empty fields (quoted or not) become null.
      def addField(): Unit = {
        if (buff.isEmpty) fields.append(null)
        else {
          val field = buff.toString()
          buff = new StringBuilder
          fields.append(field)
        }
      }

      // True when the previous character inside a quoted field was a quote character.
      var escaped = false
      var index = 0

      // Each iteration consumes exactly one field plus the delimiter that follows it.
      while (index < s.length) {
        escaped = false
        if (s(index) == quoteChar) {
          // Quoted field: skip the opening quote.
          index += 1
          // Stop at end of input, or at a delimiter that directly follows a quote
          // (i.e. the closing quote of the field).
          while (index < s.length && !(escaped && s(index) == delimiter)) {
            if (escaped) {
              // Character after a quote is taken literally, so "" yields a single ".
              escaped = false
              buff.append(s(index))
            } else if (s(index) == quoteChar) escaped = true
            else buff.append(s(index))
            index += 1
          }
          addField()
        } else {
          // Unquoted field: everything up to the next delimiter.
          while (index < s.length && s(index) != delimiter) {
            buff.append(s(index))
            index += 1
          }
          addField()
        }
        // Step over the delimiter that ended the field.
        index += 1
      }
      // Always appends one more field after the loop. The buffer is empty here, so this adds a
      // null: it stands for the empty last field when the line ends with a delimiter, and is a
      // surplus trailing null otherwise.
      // NOTE(review): presumably the row converter ignores surplus trailing fields — confirm.
      addField()
      converter(fields.toArray)
    })
  }

}
Example 5
Source File: InterfaceTreeSpec.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.codegen

import com.daml.lf.data.ImmArray
import com.daml.lf.data.ImmArray.ImmArraySeq
import com.daml.lf.data.Ref.{DottedName, QualifiedName, PackageId}
import com.daml.lf.iface.{DefDataType, Interface, InterfaceType, Record, Variant}
import org.scalatest.{FlatSpec, Matchers}

import scala.collection.mutable.ArrayBuffer

/** Specs for `InterfaceTree.bfs` traversal order and for `InterfaceTree.fromInterface`. */
class InterfaceTreeSpec extends FlatSpec with Matchers {

  behavior of "InterfaceTree.bfs"

  it should "traverse an empty tree" in {
    val emptyInterface = Interface(PackageId.assertFromString("packageid"), Map.empty)
    val interfaceTree = InterfaceTree(Map.empty, emptyInterface)
    // the folding function is never applied, so the seed comes back unchanged
    interfaceTree.bfs(0)((x, _) => x + 1) shouldEqual 0
  }

  it should "traverse a tree with n elements in bfs order" in {
    // foo:bar
    val fooBar = QualifiedName(
      DottedName.assertFromSegments(ImmArray("foo").toSeq),
      DottedName.assertFromSegments(ImmArray("bar").toSeq))
    val record1 = InterfaceType.Normal(DefDataType(ImmArraySeq(), Record(ImmArraySeq())))
    // foo:bar.baz
    val fooBarBaz = QualifiedName(
      DottedName.assertFromSegments(ImmArray("foo").toSeq),
      DottedName.assertFromSegments(ImmArray("bar", "baz").toSeq))
    val variant1 = InterfaceType.Normal(DefDataType(ImmArraySeq(), Variant(ImmArraySeq())))
    // foo:qux
    val fooQux = QualifiedName(
      DottedName.assertFromSegments(ImmArray("foo").toSeq),
      DottedName.assertFromSegments(ImmArray("qux").toSeq))
    val record2 = InterfaceType.Normal(DefDataType(ImmArraySeq(), Record(ImmArraySeq())))

    val declarations = Map(fooBar -> record1, fooBarBaz -> variant1, fooQux -> record2)
    val tree =
      InterfaceTree.fromInterface(new Interface(PackageId.assertFromString("packageId2"), declarations))

    // collect the types in visiting order; module nodes contribute nothing
    val visited = tree.bfs(ArrayBuffer.empty[InterfaceType]) { (acc, node) =>
      node match {
        case ModuleWithContext(_, _, _, _) => acc
        case TypeWithContext(_, _, _, _, typ) => acc ++= typ.typ.toList
      }
    }
    visited should contain theSameElementsInOrderAs Seq(record1, record2, variant1)
  }

  behavior of "InterfaceTree.fromInterface"

  it should "permit standalone types with multi-component names" in {
    val bazQuux = QualifiedName(
      DottedName.assertFromSegments(ImmArray("foo", "bar").toSeq),
      DottedName.assertFromSegments(ImmArray("baz", "quux").toSeq)
    )
    val record = InterfaceType.Normal(DefDataType(ImmArraySeq(), Record(ImmArraySeq())))

    val declarations = Map(bazQuux -> record)
    val tree =
      InterfaceTree.fromInterface(new Interface(PackageId.assertFromString("pkgid"), declarations))

    val visited = tree.bfs(ArrayBuffer.empty[InterfaceType]) { (types, node) =>
      node match {
        case _: ModuleWithContext => types
        case TypeWithContext(_, _, _, _, tpe) => types ++= tpe.typ.toList
      }
    }
    visited.toList shouldBe List(record)
  }

}
Example 6
Source File: SpearmanCorrelation.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.stat.correlation

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
import org.apache.spark.rdd.RDD


  /**
   * Computes a correlation matrix from the per-column ranks of `X`.
   *
   * Every value is replaced by its rank within its column (tied values share their average
   * rank), the ranks are reassembled into one vector per input row, and the result is passed to
   * `PearsonCorrelation.computeCorrelationMatrix`.
   *
   * @param X input rows; each vector is one row, each vector position one column
   * @return the matrix returned by `PearsonCorrelation.computeCorrelationMatrix` for the ranks
   */
  override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
    // ((columnIndex, value), rowUid)
    val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) =>
      vec.toArray.view.zipWithIndex.map { case (v, j) =>
        ((j, v), uid)
      }
    }
    // global sort by (columnIndex, value)
    val sorted = colBased.sortByKey()
    // assign global ranks (using average ranks for tied values)
    val globalRanks = sorted.zipWithIndex().mapPartitions { iter =>
      // State of the run of equal (column, value) entries currently being collected.
      var preCol = -1
      var preVal = Double.NaN
      var startRank = -1.0
      var cachedUids = ArrayBuffer.empty[Long]
      // Emits (uid, (column, averageRank)) for every cached uid and empties the cache.
      // The average rank of a run is the midpoint between its first and last position.
      val flush: () => Iterable[(Long, (Int, Double))] = () => {
        val averageRank = startRank + (cachedUids.size - 1) / 2.0
        val output = cachedUids.map { uid =>
          (uid, (preCol, averageRank))
        }
        cachedUids.clear()
        output
      }
      iter.flatMap { case (((j, v), uid), rank) =>
        // If we see a new value or cachedUids is too big, we flush ids with their average rank.
        // preVal starts as NaN and NaN != NaN, so the first element always takes this branch
        // (flushing an empty cache); NaN inputs likewise never join a run.
        if (j != preCol || v != preVal || cachedUids.size >= 10000000) {
          val output = flush()
          preCol = j
          preVal = v
          startRank = rank
          cachedUids += uid
          output
        } else {
          cachedUids += uid
          Iterator.empty
        }
      // `++` takes its argument by name, so this last flush runs only after the iterator above
      // is exhausted and emits the final run of the partition.
      } ++ flush()
    }
    // Replace values in the input matrix by their ranks compared with values in the same column.
    // Note that shifting all ranks in a column by a constant value doesn't affect result.
    val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) =>
      // sort by column index and then convert values to a vector
      Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray)
    }
    PearsonCorrelation.computeCorrelationMatrix(groupedRanks)
  }
} 
Example 7
Source File: KPLBasedKinesisTestUtils.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.kinesis

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult}
import com.google.common.util.concurrent.{FutureCallback, Futures}

/** Kinesis test utilities that can generate data through the KPL-based generator. */
private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils {
  /**
   * Picks the data generator: a [[KPLDataGenerator]] when `aggregate` is true, otherwise a
   * `SimpleDataGenerator`.
   */
  override protected def getProducer(aggregate: Boolean): KinesisDataGenerator =
    if (aggregate) new KPLDataGenerator(regionName)
    else new SimpleDataGenerator(kinesisClient)
}


/** Data generator that sends records through a `KinesisProducer` configured for `regionName`. */
private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator {

  // Created on first use.
  private lazy val producer: KPLProducer = {
    val producerConf = new KinesisProducerConfiguration()
      .setRecordMaxBufferedTime(1000)
      .setMaxConnections(1)
      .setRegion(regionName)
      .setMetricsLevel("none")

    new KPLProducer(producerConf)
  }

  /**
   * Sends every number in `data` to `streamName`, using the number's decimal string both as the
   * second argument of `addUserRecord` and (UTF-8 encoded) as the payload, then flushes the
   * producer.
   *
   * @return for each shard id reported by a successful send, the (number, sequence number)
   *         pairs recorded for it; failed sends are ignored
   */
  override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = {
    val seqNumbersByShard = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()

    // Callback that files the result of sending `num` under the shard it landed on.
    def recordingCallback(num: Int): FutureCallback[UserRecordResult] =
      new FutureCallback[UserRecordResult]() {
        override def onFailure(t: Throwable): Unit = {} // do nothing

        override def onSuccess(result: UserRecordResult): Unit = {
          val shardId = result.getShardId
          val seqNumber = result.getSequenceNumber()
          val forShard = seqNumbersByShard.getOrElseUpdate(shardId,
            new ArrayBuffer[(Int, String)]())
          forShard += ((num, seqNumber))
        }
      }

    for (num <- data) {
      val text = num.toString
      val payload = ByteBuffer.wrap(text.getBytes(StandardCharsets.UTF_8))
      val pending = producer.addUserRecord(streamName, text, payload)
      Futures.addCallback(pending, recordingCallback(num))
    }
    producer.flushSync()
    seqNumbersByShard.toMap
  }
}
Example 8
Source File: Exchange.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType


/**
 * Rule that replaces an exchange with a `ReusedExchangeExec` pointing at an earlier exchange
 * whenever `sameResult` reports the two as equivalent. Does nothing unless
 * `conf.exchangeReuseEnabled` is set.
 */
case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan =
    if (conf.exchangeReuseEnabled) reuseExchanges(plan) else plan

  /** Walks `plan` bottom-up and substitutes exchanges that duplicate one seen earlier. */
  private def reuseExchanges(plan: SparkPlan): SparkPlan = {
    // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
    val exchangesBySchema = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
    plan.transformUp {
      case exchange: Exchange =>
        // the exchanges that have same results usually also have same schemas (same column names).
        val candidates = exchangesBySchema.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
        candidates.find(seen => exchange.sameResult(seen)) match {
          case Some(seen) =>
            // Keep the output of this exchange, the following plans require that to resolve
            // attributes.
            ReusedExchangeExec(exchange.output, seen)
          case None =>
            candidates += exchange
            exchange
        }
    }
  }
}
Example 9
Source File: subquery.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}


/**
 * Rule that points a subquery expression at the plan of an earlier subquery whenever
 * `sameResult` reports the two plans as equivalent. Does nothing unless
 * `conf.exchangeReuseEnabled` is set.
 */
case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan =
    if (conf.exchangeReuseEnabled) reuseSubqueries(plan) else plan

  /** Rewrites the subquery expressions of `plan` that duplicate one seen earlier. */
  private def reuseSubqueries(plan: SparkPlan): SparkPlan = {
    // Group the subquery plans seen so far by schema to avoid O(N*N) sameResult calls.
    val subqueriesBySchema = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
    plan transformAllExpressions {
      case sub: ExecSubqueryExpression =>
        val candidates =
          subqueriesBySchema.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
        candidates.find(_.sameResult(sub.plan)) match {
          case Some(seen) =>
            sub.withNewPlan(seen)
          case None =>
            candidates += sub.plan
            sub
        }
    }
  }
}
Example 10
Source File: ApplicationMasterArguments.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.yarn

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.util.{IntParam, MemoryParam}

/**
 * Parses the command-line arguments of the application master into fields.
 *
 * Parsing happens during construction. An unrecognised argument, or a flag without a value,
 * prints the usage text and exits the JVM with code 1; giving both a Python and an R primary
 * file exits with code -1.
 *
 * @param args raw command-line arguments
 */
class ApplicationMasterArguments(val args: Array[String]) {
  var userJar: String = null
  var userClass: String = null
  var primaryPyFile: String = null
  var primaryRFile: String = null
  var userArgs: Seq[String] = Nil
  var propertiesFile: String = null

  parseArgs(args.toList)

  /** Consumes `inputArgs` as flag/value pairs and stores each value in its field. */
  private def parseArgs(inputArgs: List[String]): Unit = {
    val collectedUserArgs = new ArrayBuffer[String]()

    // --num-workers, --worker-memory, and --worker-cores are deprecated since 1.0,
    // the properties with executor in their names are preferred.
    @scala.annotation.tailrec
    def consume(remaining: List[String]): Unit = remaining match {
      case Nil =>
        ()

      case "--jar" :: value :: rest =>
        userJar = value
        consume(rest)

      case "--class" :: value :: rest =>
        userClass = value
        consume(rest)

      case "--primary-py-file" :: value :: rest =>
        primaryPyFile = value
        consume(rest)

      case "--primary-r-file" :: value :: rest =>
        primaryRFile = value
        consume(rest)

      case "--arg" :: value :: rest =>
        collectedUserArgs += value
        consume(rest)

      case "--properties-file" :: value :: rest =>
        propertiesFile = value
        consume(rest)

      case unrecognised =>
        // reports the whole unparsed remainder, then exits
        printUsageAndExit(1, unrecognised)
    }

    consume(inputArgs)

    val bothPrimaryFilesGiven = primaryPyFile != null && primaryRFile != null
    if (bothPrimaryFilesGiven) {
      // scalastyle:off println
      System.err.println("Cannot have primary-py-file and primary-r-file at the same time")
      // scalastyle:on println
      System.exit(-1)
    }

    userArgs = collectedUserArgs.toList
  }

  /**
   * Prints the usage text to stderr, preceded by a line naming `unknownParam` when it is not
   * null, and exits the JVM with `exitCode`.
   */
  def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = {
    // scalastyle:off println
    Option(unknownParam).foreach { param =>
      System.err.println("Unknown/unsupported param " + param)
    }
    System.err.println("""
      |Usage: org.apache.spark.deploy.yarn.ApplicationMaster [options]
      |Options:
      |  --jar JAR_PATH       Path to your application's JAR file
      |  --class CLASS_NAME   Name of your application's main class
      |  --primary-py-file    A main Python file
      |  --primary-r-file     A main R file
      |  --arg ARG            Argument to be passed to your application's main class.
      |                       Multiple invocations are possible, each will be passed in order.
      |  --properties-file FILE Path to a custom Spark properties file.
      """.stripMargin)
    // scalastyle:on println
    System.exit(exitCode)
  }
}

object ApplicationMasterArguments {
  // Named as the default executor count; nothing in this excerpt reads it.
  val DEFAULT_NUMBER_EXECUTORS: Int = 2
}
Example 11
Source File: ClientArguments.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.yarn

import scala.collection.mutable.ArrayBuffer

// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware !
/**
 * Parses the command-line arguments of the YARN client into fields.
 *
 * Parsing happens during construction and throws `IllegalArgumentException` for an
 * unrecognised argument, for a flag without a value, and when both a Python and an R primary
 * file are given.
 *
 * @param args raw command-line arguments
 */
private[spark] class ClientArguments(args: Array[String]) {

  var userJar: String = null
  var userClass: String = null
  var primaryPyFile: String = null
  var primaryRFile: String = null
  var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]()

  parseArgs(args.toList)

  /** Consumes `inputArgs` as flag/value pairs and stores each value in its field. */
  private def parseArgs(inputArgs: List[String]): Unit = {
    @scala.annotation.tailrec
    def consume(remaining: List[String]): Unit = remaining match {
      case Nil =>
        ()

      case "--jar" :: value :: rest =>
        userJar = value
        consume(rest)

      case "--class" :: value :: rest =>
        userClass = value
        consume(rest)

      case "--primary-py-file" :: value :: rest =>
        primaryPyFile = value
        consume(rest)

      case "--primary-r-file" :: value :: rest =>
        primaryRFile = value
        consume(rest)

      case "--arg" :: value :: rest =>
        userArgs += value
        consume(rest)

      case unrecognised =>
        // the message lists the whole unparsed remainder
        throw new IllegalArgumentException(getUsageMessage(unrecognised))
    }

    consume(inputArgs)

    val bothPrimaryFilesGiven = primaryPyFile != null && primaryRFile != null
    if (bothPrimaryFilesGiven) {
      throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" +
        " at the same time")
    }
  }

  /** Builds the usage text, preceded by a line naming `unknownParam` when it is not null. */
  private def getUsageMessage(unknownParam: List[String] = null): String = {
    val prefix = Option(unknownParam).fold("")(param => s"Unknown/unsupported param $param\n")
    prefix +
      s"""
      |Usage: org.apache.spark.deploy.yarn.Client [options]
      |Options:
      |  --jar JAR_PATH           Path to your application's JAR file (required in yarn-cluster
      |                           mode)
      |  --class CLASS_NAME       Name of your application's main class (required)
      |  --primary-py-file        A main Python file
      |  --primary-r-file         A main R file
      |  --arg ARG                Argument to be passed to your application's main class.
      |                           Multiple invocations are possible, each will be passed in order.
      """.stripMargin
  }
}
Example 12
Source File: YarnClientSchedulerBackend.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.yarn.api.records.YarnApplicationState

import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil}
import org.apache.spark.internal.Logging
import org.apache.spark.launcher.SparkAppHandle
import org.apache.spark.scheduler.TaskSchedulerImpl

/** YARN scheduler backend variant holding a `Client` and an optional monitor thread. */
private[spark] class YarnClientSchedulerBackend(
    scheduler: TaskSchedulerImpl,
    sc: SparkContext)
  extends YarnSchedulerBackend(scheduler, sc)
  with Logging {

  private var client: Client = _
  private var monitorThread: MonitorThread = _

  /**
   * Stops the monitor thread (if any), reports a FINISHED state through the client, stops the
   * parent backend and the credential updater, and finally stops the client.
   * Fails an assertion when no client has been set.
   */
  override def stop(): Unit = {
    assert(client != null, "Attempted to stop this scheduler before starting it!")
    Option(monitorThread).foreach(_.stopMonitor())

    // Report a final state to the launcher if one is connected. This is needed since in client
    // mode this backend doesn't let the app monitor loop run to completion, so it does not report
    // the final state itself.
    //
    // Note: there's not enough information at this point to provide a better final state,
    // so assume the application was successful.
    val finalState = SparkAppHandle.State.FINISHED
    client.reportLauncherState(finalState)

    super.stop()
    YarnSparkHadoopUtil.get.stopCredentialUpdater()
    client.stop()
    logInfo("Stopped")
  }

}
Example 13
Source File: UnionDStream.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.dstream

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Duration, Time}

/**
 * DStream whose RDD for a batch is the union of the RDDs of all `parents` for that batch.
 * All parents must share one streaming context and one slide duration.
 */
private[streaming]
class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
  extends DStream[T](parents.head.ssc) {

  require(parents.length > 0, "List of DStreams to union is empty")
  require(parents.map(_.ssc).distinct.length == 1, "Some of the DStreams have different contexts")
  require(parents.map(_.slideDuration).distinct.length == 1,
    "Some of the DStreams have different slide durations")

  override def dependencies: List[DStream[_]] = parents.toList

  override def slideDuration: Duration = parents.head.slideDuration

  /**
   * Asks every parent for its RDD at `validTime` and unions the results.
   * Throws a `SparkException` when any parent yields no RDD.
   */
  override def compute(validTime: Time): Option[RDD[T]] = {
    // Query all parents first, then validate, so every parent is asked before any failure.
    val perParent = parents.map(_.getOrCompute(validTime))
    val rdds = new ArrayBuffer[RDD[T]]()
    for (maybeRdd <- perParent) {
      rdds += maybeRdd.getOrElse(
        throw new SparkException("Could not generate RDD from a parent for unifying at" +
          s" time $validTime"))
    }
    if (rdds.isEmpty) None else Some(ssc.sc.union(rdds))
  }
}
Example 14
Source File: QueueInputDStream.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.dstream

import java.io.{NotSerializableException, ObjectInputStream, ObjectOutputStream}

import scala.collection.mutable.{ArrayBuffer, Queue}
import scala.reflect.ClassTag

import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.streaming.{StreamingContext, Time}

/**
 * Input stream fed from an in-memory queue of RDDs.
 *
 * Each batch takes either the head of the queue (`oneAtATime`) or everything currently queued;
 * with nothing queued it yields `defaultRDD`, or an empty RDD when `defaultRDD` is null.
 * Serialization is rejected because the stream does not support checkpointing.
 */
private[streaming]
class QueueInputDStream[T: ClassTag](
    ssc: StreamingContext,
    val queue: Queue[RDD[T]],
    oneAtATime: Boolean,
    defaultRDD: RDD[T]
  ) extends InputDStream[T](ssc) {

  override def start(): Unit = {}

  override def stop(): Unit = {}

  // Java serialization hook: reading this stream back always fails.
  private def readObject(in: ObjectInputStream): Unit = {
    throw new NotSerializableException("queueStream doesn't support checkpointing. " +
      "Please don't use queueStream when checkpointing is enabled.")
  }

  // Java serialization hook: writes nothing and only logs a warning.
  private def writeObject(oos: ObjectOutputStream): Unit = {
    logWarning("queueStream doesn't support checkpointing")
  }

  override def compute(validTime: Time): Option[RDD[T]] = {
    val taken = new ArrayBuffer[RDD[T]]()
    queue.synchronized {
      if (!oneAtATime || queue.isEmpty) {
        // drain everything (a no-op when the queue is empty)
        taken ++= queue
        queue.clear()
      } else {
        taken += queue.dequeue()
      }
    }
    if (taken.isEmpty) {
      if (defaultRDD == null) Some(ssc.sparkContext.emptyRDD) else Some(defaultRDD)
    } else if (oneAtATime) {
      Some(taken.head)
    } else {
      Some(new UnionRDD(context.sc, taken.toSeq))
    }
  }

}
Example 15
Source File: LocalSparkCluster.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkConf
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.util.Utils


    for (workerNum <- 1 to numWorkers) {
      val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker,
        memoryPerWorker, masters, null, Some(workerNum), _conf)
      workerRpcEnvs += workerEnv
    }

    masters
  }

  /**
   * Shuts down every worker and master RPC environment, waits for each to terminate, and
   * empties both collections.
   */
  def stop(): Unit = {
    logInfo("Shutting down local Spark cluster.")
    // Stop the workers before the master so they don't get upset that it disconnected
    for (workerEnv <- workerRpcEnvs) workerEnv.shutdown()
    for (masterEnv <- masterRpcEnvs) masterEnv.shutdown()
    // wait only after every shutdown has been requested
    for (workerEnv <- workerRpcEnvs) workerEnv.awaitTermination()
    for (masterEnv <- masterRpcEnvs) masterEnv.awaitTermination()
    masterRpcEnvs.clear()
    workerRpcEnvs.clear()
  }
} 
Example 16
Source File: TaskResult.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.io._
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkEnv
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.storage.BlockId
import org.apache.spark.util.{AccumulatorV2, Utils}

/**
 * Task result. Also contains updates to accumulator variables.
 *
 * The trait is sealed, so every implementation has to be declared in this source file.
 *
 * @tparam T type of the value produced by the task
 */
private[spark] sealed trait TaskResult[T]


  /**
   * Returns the task's value, deserializing `valueBytes` on the first call and caching the
   * result in `valueObject` for later calls.
   *
   * @param resultSer serializer to deserialize with; when null, a new instance is obtained
   *                  from `SparkEnv.get.serializer`
   */
  def value(resultSer: SerializerInstance = null): T = {
    if (!valueObjectDeserialized) {
      // This should not run when holding a lock because it may cost dozens of seconds for a large
      // value
      val serializer = Option(resultSer).getOrElse(SparkEnv.get.serializer.newInstance())
      valueObject = serializer.deserialize(valueBytes)
      valueObjectDeserialized = true
    }
    valueObject
  }
} 
Example 17
Source File: Schedulable.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.scheduler.SchedulingMode.SchedulingMode


/**
 * Contract for an entity that takes part in scheduling. It has a `parent` [[Pool]], may hold
 * child schedulables, and can report its task set managers in sorted order.
 *
 * NOTE(review): the implementations are not visible from this excerpt; the member notes below
 * are read off names and signatures only and should be confirmed against the implementing
 * classes.
 */
private[spark] trait Schedulable {
  // Pool this schedulable belongs to (mutable, so it can be assigned after construction).
  var parent: Pool
  // child queues
  def schedulableQueue: ConcurrentLinkedQueue[Schedulable]
  def schedulingMode: SchedulingMode
  // Scheduling attributes; presumably consumed by the scheduling algorithm -- TODO confirm.
  def weight: Int
  def minShare: Int
  def runningTasks: Int
  def priority: Int
  def stageId: Int
  def name: String

  // Management of child schedulables.
  def addSchedulable(schedulable: Schedulable): Unit
  def removeSchedulable(schedulable: Schedulable): Unit
  def getSchedulableByName(name: String): Schedulable
  // Notification that an executor on `host` was lost for the given `reason`.
  def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit
  // NOTE(review): the Boolean presumably reports whether any speculatable task was found.
  def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean
  def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager]
}
Example 18
Source File: ChunkedByteBufferOutputStream.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.io

import java.io.OutputStream
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.storage.StorageUtils


  // Write offset inside the last chunk. It starts at chunkSize so that the very first write
  // finds the "current" chunk full and allocateNewChunkIfNeeded() allocates one.
  private[this] var position = chunkSize
  // Total number of bytes written. Kept as a Long (matching the type of `size`) so that the
  // count does not overflow once more than Int.MaxValue bytes have been written.
  private[this] var _size = 0L
  private[this] var closed: Boolean = false

  /** Total number of bytes written to this stream so far. */
  def size: Long = _size

  /** Closes the stream. Only the first call does anything; later calls are no-ops. */
  override def close(): Unit = {
    val alreadyClosed = closed
    if (!alreadyClosed) {
      super.close()
      closed = true
    }
  }

  /**
   * Appends the low-order byte of `b` to the last chunk, allocating a new chunk first when the
   * current one is full.
   *
   * @throws IllegalArgumentException if the stream has already been closed
   */
  override def write(b: Int): Unit = {
    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
    allocateNewChunkIfNeeded()
    val currentChunk = chunks(lastChunkIndex)
    currentChunk.put(b.toByte)
    position += 1
    _size += 1
  }

  /**
   * Appends `len` bytes of `bytes`, starting at `off`, spreading them over as many chunks as
   * needed. Each iteration copies at most what still fits into the current chunk.
   *
   * @throws IllegalArgumentException if the stream has already been closed
   */
  override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
    var sourceOffset = off
    var remaining = len
    while (remaining > 0) {
      allocateNewChunkIfNeeded()
      val roomInChunk = chunkSize - position
      val batch = math.min(roomInChunk, remaining)
      chunks(lastChunkIndex).put(bytes, sourceOffset, batch)
      sourceOffset += batch
      remaining -= batch
      position += batch
    }
    _size += len
  }

  /**
   * Appends a freshly allocated chunk and makes it the current one when the current chunk has
   * no room left (`position` has reached `chunkSize`).
   */
  @inline
  private def allocateNewChunkIfNeeded(): Unit = {
    val currentChunkIsFull = position == chunkSize
    if (currentChunkIsFull) {
      chunks += allocator(chunkSize)
      lastChunkIndex += 1
      position = 0
    }
  }

  /**
   * Hands the written data over as a [[ChunkedByteBuffer]].
   *
   * May only be called after `close()` and only once; both conditions are enforced with
   * `require`. All chunks except the last are flipped and reused as they are. The last chunk
   * is reused only when it is completely full; otherwise its content is copied into a buffer
   * of exactly `position` bytes obtained from `allocator`, and the original chunk is disposed.
   *
   * @return the chunks as a `ChunkedByteBuffer`; one wrapping an empty array when
   *         `lastChunkIndex` is -1 (presumably its initial value, i.e. nothing was written)
   * @throws IllegalArgumentException if the stream is not closed or this was already called
   */
  def toChunkedByteBuffer: ChunkedByteBuffer = {
    require(closed, "cannot call toChunkedByteBuffer() unless close() has been called")
    require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
    toChunkedByteBufferWasCalled = true
    if (lastChunkIndex == -1) {
      new ChunkedByteBuffer(Array.empty[ByteBuffer])
    } else {
      // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk.
      // An alternative would have been returning an array of ByteBuffers, with the last buffer
      // bounded to only the last chunk's position. However, given our use case in Spark (to put
      // the chunks in block manager), only limiting the view bound of the buffer would still
      // require the block manager to store the whole chunk.
      val ret = new Array[ByteBuffer](chunks.size)
      for (i <- 0 until chunks.size - 1) {
        ret(i) = chunks(i)
        // Switch the chunk from writing to reading.
        ret(i).flip()
      }
      if (position == chunkSize) {
        // The last chunk is completely filled, so it can be reused without copying.
        ret(lastChunkIndex) = chunks(lastChunkIndex)
        ret(lastChunkIndex).flip()
      } else {
        // Copy the partially filled last chunk into a buffer of exactly the written size.
        ret(lastChunkIndex) = allocator(position)
        chunks(lastChunkIndex).flip()
        ret(lastChunkIndex).put(chunks(lastChunkIndex))
        ret(lastChunkIndex).flip()
        // The oversized original is no longer referenced by the result; release it.
        StorageUtils.dispose(chunks(lastChunkIndex))
      }
      new ChunkedByteBuffer(ret)
    }
  }
} 
Example 19
Source File: SubtractedRDD.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.util.{HashMap => JHashMap}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.Dependency
import org.apache.spark.OneToOneDependency
import org.apache.spark.Partition
import org.apache.spark.Partitioner
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext


/**
 * RDD containing the pairs of `rdd1` whose key does not appear in `rdd2`.
 *
 * For each partition, all values of `rdd1` are collected per key in an in-memory hash map;
 * every key seen in `rdd2` is then removed from that map and the remainder is emitted.
 *
 * @param rdd1 source of the pairs to keep; set to null by `clearDependencies`
 * @param rdd2 source of the keys to remove; set to null by `clearDependencies`
 * @param part partitioner of the result; a parent that already uses it is read through a
 *             one-to-one dependency, any other parent through a shuffle
 */
private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
    @transient var rdd1: RDD[_ <: Product2[K, V]],
    @transient var rdd2: RDD[_ <: Product2[K, W]],
    part: Partitioner)
  extends RDD[(K, V)](rdd1.context, Nil) {


  /** One dependency per parent, in the order (rdd1, rdd2). */
  override def getDependencies: Seq[Dependency[_]] = {
    def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]])
      : Dependency[_] = {
      if (rdd.partitioner == Some(part)) {
        logDebug("Adding one-to-one dependency with " + rdd)
        new OneToOneDependency(rdd)
      } else {
        logDebug("Adding shuffle dependency with " + rdd)
        new ShuffleDependency[T1, T2, Any](rdd, part)
      }
    }
    Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2))
  }

  /** One [[CoGroupPartition]] per partition of `part`. */
  override def getPartitions: Array[Partition] = {
    val array = new Array[Partition](part.numPartitions)
    for (i <- array.indices) {
      // Each CoGroupPartition will depend on rdd1 and rdd2
      array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
        dependencies(j) match {
          case s: ShuffleDependency[_, _, _] =>
            // Shuffled parents carry no narrow split information.
            None
          case _ =>
            Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)))
        }
      }.toArray)
    }
    array
  }

  override val partitioner = Some(part)

  /** Computes the pairs of partition `p` that survive the subtraction. */
  override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
    val partition = p.asInstanceOf[CoGroupPartition]
    val map = new JHashMap[K, ArrayBuffer[V]]
    // Returns the value buffer registered for `k`, registering an empty one on first use.
    def getSeq(k: K): ArrayBuffer[V] = {
      val existing = map.get(k)
      if (existing != null) {
        existing
      } else {
        val created = new ArrayBuffer[V]()
        map.put(k, created)
        created
      }
    }
    // Applies `op` to every record of the given parent for this partition.
    def integrate(depNum: Int, op: Product2[K, V] => Unit): Unit = {
      dependencies(depNum) match {
        case oneToOneDependency: OneToOneDependency[_] =>
          val dependencyPartition = partition.narrowDeps(depNum).get.split
          oneToOneDependency.rdd.iterator(dependencyPartition, context)
            .asInstanceOf[Iterator[Product2[K, V]]].foreach(op)

        case shuffleDependency: ShuffleDependency[_, _, _] =>
          val iter = SparkEnv.get.shuffleManager
            .getReader(
              shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context)
            .read()
          iter.foreach(op)
      }
    }

    // the first dep is rdd1; add all values to the map
    integrate(0, t => getSeq(t._1) += t._2)
    // the second dep is rdd2; remove all of its keys
    integrate(1, t => map.remove(t._1))
    map.asScala.iterator.flatMap { case (key, values) =>
      values.iterator.map(value => (key, value))
    }
  }

  /** Drops the references to both parents in addition to the inherited clean-up. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
  }

}
Example 20
Source File: UnionRDD.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.{ForkJoinTaskSupport, ThreadPoolTaskSupport}
import scala.concurrent.forkjoin.ForkJoinPool
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils


/**
 * Partition of a union that maps to one partition of one of the parent RDDs.
 *
 * @param idx index of this partition
 * @param rdd parent RDD the partition belongs to; marked transient
 * @param parentRddIndex index of the parent RDD
 * @param parentRddPartitionIndex index of the partition within the parent RDD; marked
 *                                transient
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient private val rdd: RDD[T],
    val parentRddIndex: Int,
    @transient private val parentRddPartitionIndex: Int)
  extends Partition {

  // Resolved at construction time and refreshed again in writeObject.
  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  /** Preferred locations reported by the parent RDD for `parentPartition`. */
  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  // Java serialization hook: re-resolves the parent partition before the default field write.
  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)
    oos.defaultWriteObject()
  }
}

object UnionRDD {
  // Task support backed by a fork-join pool with parallelism 8; created on first use.
  private[spark] lazy val partitionEvalTaskSupport = {
    val pool = new ForkJoinPool(8)
    new ForkJoinTaskSupport(pool)
  }
}

/**
 * RDD whose partitions are the partitions of all `rdds`, in order.
 *
 * @param sc   context the RDD is created in
 * @param rdds parent RDDs; set to null by `clearDependencies`
 */
@DeveloperApi
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  // visible for testing
  // True when there are more parents than "spark.rdd.parallelListingThreshold" (default 10).
  private[spark] val isPartitionListingParallel: Boolean =
    rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)

  /** Creates one [[UnionPartition]] per parent partition, numbered consecutively. */
  override def getPartitions: Array[Partition] = {
    // With many parents, the partition counts are evaluated through a parallel collection.
    val parRDDs = if (isPartitionListingParallel) {
      val parArray = rdds.par
      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
      parArray
    } else {
      rdds
    }
    val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
    var pos = 0
    for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
      array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
      pos += 1
    }
    array
  }

  /** One [[RangeDependency]] per parent, each starting where the previous parent ended. */
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length)
      pos += rdd.partitions.length
    }
    deps
  }

  /** Delegates to the parent RDD the partition was taken from. */
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)
  }

  override def getPreferredLocations(s: Partition): Seq[String] =
    s.asInstanceOf[UnionPartition[T]].preferredLocations()

  /** Drops the reference to the parents in addition to the inherited clean-up. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdds = null
  }
}
Example 21
Source File: TaskContextImpl.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import java.util.Properties

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.util._

/**
 * [[TaskContext]] implementation carrying the identifiers, local properties, metrics and
 * memory manager of a task.
 *
 * NOTE(review): this excerpt references `interrupted` and `completed`, whose declarations are
 * not visible here -- confirm their definitions in the full source.
 */
private[spark] class TaskContextImpl(
    val stageId: Int,
    val partitionId: Int,
    override val taskAttemptId: Long,
    override val attemptNumber: Int,
    var _taskMemoryManager: TaskMemoryManager,
    localProperties: Properties,
    @transient private val metricsSystem: MetricsSystem,
    // The default value is only used in tests.
    override val taskMetrics: TaskMetrics = TaskMetrics.empty,
    var batchId: Int = 0)
  extends TaskContext
  with Logging {

  
  /** Flags the task as interrupted. */
  private[spark] def markInterrupted(): Unit = {
    interrupted = true
  }

  override def isCompleted(): Boolean = completed

  // This implementation always reports false.
  override def isRunningLocally(): Boolean = false

  override def isInterrupted(): Boolean = interrupted

  /** Looks `key` up in the local properties handed to the constructor. */
  override def getLocalProperty(key: String): String = localProperties.getProperty(key)

  /** Returns the sources the metrics system has registered under `sourceName`. */
  override def getMetricsSources(sourceName: String): Seq[Source] =
    metricsSystem.getSourcesByName(sourceName)

  /** Forwards the accumulator registration to `taskMetrics`. */
  private[spark] override def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {
    taskMetrics.registerAccumulator(a)
  }

}
Example 22
Source File: TimeStampedHashMapSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import org.apache.spark.SparkFunSuite

/** Tests for `TimeStampedHashMap`: basic map behaviour, timestamp clearing, thread safety. */
class TimeStampedHashMapSuite extends SparkFunSuite {

  // Test the testMap function - a Scala HashMap should obviously pass
  testMap(new mutable.HashMap[String, String]())

  // Test TimeStampedHashMap basic functionality
  testMap(new TimeStampedHashMap[String, String]())
  testMapThreadSafety(new TimeStampedHashMap[String, String]())

  test("TimeStampedHashMap - clearing by timestamp") {
    // clearing by insertion time
    val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false)
    map("k1") = "v1"
    assert(map("k1") === "v1")
    Thread.sleep(10)
    val threshTime = System.currentTimeMillis
    assert(map.getTimestamp("k1").isDefined)
    assert(map.getTimestamp("k1").get < threshTime)
    map.clearOldValues(threshTime)
    assert(map.get("k1") === None)

    // clearing by modification time
    val map1 = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true)
    map1("k1") = "v1"
    map1("k2") = "v2"
    assert(map1("k1") === "v1")
    Thread.sleep(10)
    val threshTime1 = System.currentTimeMillis
    Thread.sleep(10)
    assert(map1("k2") === "v2")     // access k2 to update its access time to > threshTime
    assert(map1.getTimestamp("k1").isDefined)
    assert(map1.getTimestamp("k1").get < threshTime1)
    assert(map1.getTimestamp("k2").isDefined)
    assert(map1.getTimestamp("k2").get >= threshTime1)
    map1.clearOldValues(threshTime1) // should only clear k1
    assert(map1.get("k1") === None)
    assert(map1.get("k2").isDefined)
  }

  
  /**
   * Registers a test in which 25 threads each perform 1000 random put/get/remove operations
   * on one map built by `hashMapConstructor`; the test fails if any thread throws.
   *
   * @param hashMapConstructor by-name factory; evaluated once to derive the test name and
   *                           once for the map under test
   */
  def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]): Unit = {
    def newMap() = hashMapConstructor
    val name = newMap().getClass.getSimpleName
    // Named so that it does not shadow the suite's testMap method.
    val sharedMap = newMap()
    @volatile var error = false

    // Picks a random key currently present in `m`, if there is one.
    def getRandomKey(m: mutable.Map[String, String]): Option[String] = {
      val keys = m.keysIterator.toSeq
      if (keys.nonEmpty) {
        Some(keys(Random.nextInt(keys.size)))
      } else {
        None
      }
    }

    val threads = (1 to 25).map(_ => new Thread() {
      override def run(): Unit = {
        try {
          (1 to 1000).foreach { _ =>
            Random.nextInt(3) match {
              case 0 =>
                sharedMap(Random.nextString(10)) = Random.nextDouble().toString // put
              case 1 =>
                getRandomKey(sharedMap).foreach(sharedMap.get) // get
              case 2 =>
                getRandomKey(sharedMap).foreach(sharedMap.remove) // remove
            }
          }
        } catch {
          // Deliberately broad: any failure in a worker thread must fail the test.
          case t: Throwable =>
            error = true
            throw t
        }
      }
    })

    test(name + " - threading safety test")  {
      threads.foreach(_.start())
      threads.foreach(_.join())
      assert(!error)
    }
  }
}
Example 23
Source File: Predict.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.example.lenetLocal
import com.intel.analytics.bigdl.dataset.image.{BytesToGreyImg, GreyImgNormalizer, GreyImgToSample}
import com.intel.analytics.bigdl.nn.Module
import com.intel.analytics.bigdl.utils.Engine
import com.intel.analytics.bigdl.dataset.Sample
import com.intel.analytics.bigdl.optim.LocalPredictor
import org.apache.log4j.{Level, Logger}

import scala.collection.mutable.ArrayBuffer

/**
 * Command-line entry point: loads a saved model, runs it on the images and labels found in
 * the given folder, and prints one predicted class per sample.
 */
object Predict {
  Logger.getLogger("org").setLevel(Level.ERROR)
  Logger.getLogger("akka").setLevel(Level.ERROR)
  Logger.getLogger("breeze").setLevel(Level.ERROR)


  import Utils._

  /**
   * @param args command-line arguments, parsed by `predictParser`; nothing is run when
   *             parsing yields no parameters
   */
  def main(args: Array[String]): Unit = {
    predictParser.parse(args, new PredictParams()).foreach { param =>

      System.setProperty("bigdl.localMode", "true")
      System.setProperty("bigdl.coreNumber", param.coreNumber.toString)
      Engine.init

      val validationData = param.folder + "/t10k-images-idx3-ubyte"
      val validationLabel = param.folder + "/t10k-labels-idx1-ubyte"

      val rawData = load(validationData, validationLabel)
      val iter = rawData.iterator
      val sampleIter = GreyImgToSample()(
          GreyImgNormalizer(trainMean, trainStd)(
          BytesToGreyImg(28, 28)(iter)))
      // The buffer itself is mutated, the reference never changes, so a val is enough.
      val samplesBuffer = ArrayBuffer[Sample[Float]]()
      while (sampleIter.hasNext) {
        // NOTE(review): each element is cloned before being stored, presumably because the
        // transformer pipeline reuses its output object -- confirm.
        val elem = sampleIter.next().clone()
        samplesBuffer += elem
      }
      val samples = samplesBuffer.toArray

      val model = Module.load[Float](param.model)
      val localPredictor = LocalPredictor(model)
      // The raw prediction result is not used below; the call is kept as in the original.
      val result = localPredictor.predict(samples)
      val result_class = localPredictor.predictClass(samples)
      result_class.foreach(r => println(s"${r}"))
    }
  }
}
Example 24
Source File: BatchSampler.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.transform.vision.image.label.roi

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.transform.vision.image.util.{BboxUtil, BoundingBox}
import com.intel.analytics.bigdl.utils.RandomGenerator._

import scala.collection.mutable.ArrayBuffer



  /**
   * Empties `sampledBoxes` and then lets every sampler, in array order, add its samples for
   * `label` to it, each one sampling against the unit box (0, 0, 1, 1).
   *
   * @param label         label handed to every sampler
   * @param batchSamplers samplers to run
   * @param sampledBoxes  output buffer; its previous content is discarded
   */
  def generateBatchSamples(label: RoiLabel, batchSamplers: Array[BatchSampler],
    sampledBoxes: ArrayBuffer[BoundingBox]): Unit = {
    sampledBoxes.clear()
    val unitBox = BoundingBox(0, 0, 1, 1)
    batchSamplers.foreach { sampler =>
      sampler.sample(unitBox, label, sampledBoxes)
    }
  }
} 
Example 25
Source File: RandomSampler.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.transform.vision.image.label.roi

import com.intel.analytics.bigdl.transform.vision.image.{FeatureTransformer, ImageFeature}
import com.intel.analytics.bigdl.transform.vision.image.augmentation.Crop
import com.intel.analytics.bigdl.transform.vision.image.util.{BoundingBox}
import com.intel.analytics.bigdl.utils.RandomGenerator._
import org.opencv.core.Mat

import scala.collection.mutable.ArrayBuffer


/**
 * Crop whose region of interest is drawn at random from the boxes proposed by a fixed set of
 * batch samplers.
 */
class RandomSampler extends Crop {
  // random cropping samplers: one limited to a single trial, five with increasing minimum
  // overlap, and one with a maximum overlap
  val batchSamplers = {
    val singleTrial = new BatchSampler(maxTrials = 1)
    val minOverlapSamplers = Array(0.1, 0.3, 0.5, 0.7, 0.9).map { overlap =>
      new BatchSampler(minScale = 0.3, minAspectRatio = 0.5, maxAspectRatio = 2,
        minOverlap = Some(overlap))
    }
    val maxOverlapSampler = new BatchSampler(minScale = 0.3, minAspectRatio = 0.5,
      maxAspectRatio = 2, maxOverlap = Some(1.0))
    (singleTrial +: minOverlapSamplers) :+ maxOverlapSampler
  }

  /**
   * Picks the crop region for `feature`.
   *
   * @return one of the sampled boxes chosen uniformly at random, or the unit box
   *         (0, 0, 1, 1) when the samplers produced nothing
   */
  def generateRoi(feature: ImageFeature): BoundingBox = {
    val roiLabel = feature(ImageFeature.label).asInstanceOf[RoiLabel]
    val candidates = new ArrayBuffer[BoundingBox]()
    BatchSampler.generateBatchSamples(roiLabel, batchSamplers, candidates)

    if (candidates.isEmpty) {
      BoundingBox(0, 0, 1, 1)
    } else {
      // Scale a random number drawn from uniform(0, 1) to an index into the candidates.
      val index = (RNG.uniform(0, 1) * candidates.length).toInt
      candidates(index)
    }
  }
}

object RandomSampler {
  /** Builds a new [[RandomSampler]] chained with `RoiProject()` through `->`. */
  def apply(): FeatureTransformer = {
    val sampler = new RandomSampler()
    sampler -> RoiProject()
  }
}
Example 26
Source File: RoiTransformer.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.transform.vision.image.label.roi

import com.intel.analytics.bigdl.transform.vision.image.util.{BboxUtil, BoundingBox}
import com.intel.analytics.bigdl.transform.vision.image.{FeatureTransformer, ImageFeature}

import scala.collection.mutable.ArrayBuffer


/**
 * Projects the ground-truth boxes of an image's [[RoiLabel]] onto the image's bounding box
 * and rewrites the label in place with the boxes that survive the projection.
 *
 * @param needMeetCenterConstraint when true, only boxes accepted by
 *                                 `meetEmitCenterConstraint` of the image boundary are
 *                                 projected
 */
case class RoiProject(needMeetCenterConstraint: Boolean = true) extends FeatureTransformer {
  // Scratch buffer reused across calls; cleared at the start of every transformMat.
  val transformedAnnot = new ArrayBuffer[BoundingBox]()

  override def transformMat(feature: ImageFeature): Unit = {
    val imageBoundary = feature[BoundingBox](ImageFeature.boundingBox)
    if (!imageBoundary.normalized) {
      imageBoundary.scaleBox(1.0f / feature.getHeight(), 1f / feature.getWidth(), imageBoundary)
    }
    val target = feature[RoiLabel](ImageFeature.label)
    transformedAnnot.clear()

    // Transform the annotation according to bounding box (tensor indices are 1-based).
    for (row <- 1 to target.size()) {
      val boxes = target.bboxes
      val gtBox = BoundingBox(boxes.valueAt(row, 1), boxes.valueAt(row, 2),
        boxes.valueAt(row, 3), boxes.valueAt(row, 4))
      val eligible = !needMeetCenterConstraint || imageBoundary.meetEmitCenterConstraint(gtBox)
      if (eligible) {
        val projected = new BoundingBox()
        if (imageBoundary.projectBbox(gtBox, projected)) {
          projected.setLabel(target.classes.valueAt(1, row))
          projected.setDifficult(target.classes.valueAt(2, row))
          transformedAnnot.append(projected)
        }
      }
    }

    // write the transformed annotation back to target
    val kept = transformedAnnot.length
    target.bboxes.resize(kept, 4)
    target.classes.resize(2, kept)
    for (row <- 1 to kept) {
      val box = transformedAnnot(row - 1)
      target.bboxes.setValue(row, 1, box.x1)
      target.bboxes.setValue(row, 2, box.y1)
      target.bboxes.setValue(row, 3, box.x2)
      target.bboxes.setValue(row, 4, box.y2)
      target.classes.setValue(1, row, box.label)
      target.classes.setValue(2, row, box.difficult)
    }
  }
}
Example 27
Source File: Mean.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf.loaders

import java.nio.ByteOrder

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.Sequential
import com.intel.analytics.bigdl.nn.tf.Mean
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.tf.Context
import org.tensorflow.framework.{DataType, NodeDef}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/** Loader that builds a [[MeanLoadTF]] from a node definition. */
class Mean extends TensorflowOpsLoader {

  import Utils._

  /**
   * Reads the "T" and "keep_dims" attributes of `nodeDef` and builds the matching
   * [[MeanLoadTF]]; the result squeezes exactly when "keep_dims" is false.
   *
   * @throws UnsupportedOperationException if the node's data type is not one of the handled
   *                                       integer or floating point types
   */
  override def build[T: ClassTag](nodeDef: NodeDef, byteOrder: ByteOrder
    , context: Context[T])(implicit ev: TensorNumeric[T]): Module[T] = {
    val attr = nodeDef.getAttrMap
    val dataType = getType(attr, "T")
    val squeeze = !getBoolean(attr, "keep_dims")
    // Name of the element type the mean layer is instantiated with.
    val dt = dataType match {
      case DataType.DT_INT8 | DataType.DT_INT16 | DataType.DT_UINT8 | DataType.DT_UINT16 |
           DataType.DT_INT32 =>
        "Int"
      case DataType.DT_INT64 =>
        "Long"
      case DataType.DT_FLOAT =>
        "Float"
      case DataType.DT_DOUBLE =>
        "Double"
      case _ =>
        throw new UnsupportedOperationException(s"Data Type: $dataType is not supported yet.")
    }
    new MeanLoadTF[T](dt, squeeze)
  }
}

/**
 * Adapter that builds a chain of mean layers, one per dimension listed in its first tensor
 * argument.
 *
 * @param dataType element type name selecting the layer instantiation: "Int", "Long",
 *                 "Float" or "Double"
 * @param squeeze  forwarded to every mean layer as its `squeeze` argument
 */
class MeanLoadTF[T: ClassTag](val dataType: String,
                              val squeeze: Boolean)(implicit ev: TensorNumeric[T])
  extends Adapter[T](Array(2)) {

  /**
   * @param tensorArrays its first element holds the dimensions to reduce; each value is
   *                     shifted by one before being handed to the layer
   * @throws UnsupportedOperationException if `dataType` is not one of the supported names
   */
  override def build(tensorArrays: Array[Tensor[_]]): AbstractModule[Activity, Activity, T] = {
    val dims = tensorArrays(0).asInstanceOf[Tensor[Int]]
    val mean = Sequential[T]()
    // NOTE(review): the +1 presumably converts 0-based axes to 1-based ones -- confirm.
    val dim = (1 to dims.size(1)).map(i => dims.valueAt(i) + 1)
    dataType match {
      case "Int" =>
        dim.foreach(i => mean.add(Mean[T, Int](i, squeeze = squeeze)))
      case "Long" =>
        dim.foreach(i => mean.add(Mean[T, Long](i, squeeze = squeeze)))
      case "Float" =>
        dim.foreach(i => mean.add(Mean[T, Float](i, squeeze = squeeze)))
      case "Double" =>
        dim.foreach(i => mean.add(Mean[T, Double](i, squeeze = squeeze)))
      case _ =>
        throw new UnsupportedOperationException(s"Data Type: $dataType is not supported yet.")
    }
    mean
  }
}
Example 28
Source File: Transpose.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf.loaders

import java.nio.ByteOrder

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.{Contiguous, Sequential, Transpose => TransposeLayer}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.tf.Context
import org.tensorflow.framework.NodeDef

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/** Loader that builds a [[TransposeLoadTF]]. */
class Transpose extends TensorflowOpsLoader {

  import Utils._

  /** Returns a new [[TransposeLoadTF]]; none of the three arguments is read. */
  override def build[T: ClassTag](
    nodeDef: NodeDef,
    byteOrder: ByteOrder,
    context: Context[T])(implicit ev: TensorNumeric[T]): Module[T] = new TransposeLoadTF[T]()
}

object TransposeLoadTF {

  /**
   * Expresses a permutation as a sequence of pairwise index swaps.
   *
   * The inverse of `perm` (for each value, the position it occupies in `perm`) is sorted with
   * an in-place quicksort, and every exchange of two distinct positions made while sorting is
   * recorded in order.
   *
   * @param perm a permutation of `0 until perm.length`; may be empty
   * @return the recorded (i, j) swaps, 0-based, in the order they were performed; empty for
   *         an empty, single-element or identity permutation
   * @throws NoSuchElementException if some index in `0 until perm.length` does not occur in
   *                                `perm`
   */
  def permToPair(perm: Array[Int]): Array[(Int, Int)] = {
    val numToRank = perm.zipWithIndex.toMap
    val pairs = ArrayBuffer[(Int, Int)]()

    // Swaps two entries and records the swap (including the no-op i == j, filtered later).
    def exchangeNumbers(arr: Array[Int], i: Int, j: Int): Unit = {
      val temp = arr(i)
      arr(i) = arr(j)
      arr(j) = temp
      pairs += ((i, j))
    }

    // Quicksort on arr(low..high) with the middle element as pivot.
    def sort(arr: Array[Int], low: Int, high: Int): Unit = {
      var i = low
      var j = high
      val pivot = arr(low + (high - low) / 2)

      while (i <= j) {
        while (arr(i) < pivot) i += 1
        while (arr(j) > pivot) j -= 1

        if (i <= j) {
          exchangeNumbers(arr, i, j)
          i += 1
          j -= 1
        }
      }

      if (low < j) sort(arr, low, j)
      if (i < high) sort(arr, i, high)
    }

    val ranks = perm.indices.toArray.map(numToRank)
    // Fewer than two elements need no swaps; the guard also keeps the pivot lookup from
    // indexing into an empty array.
    if (ranks.length > 1) {
      sort(ranks, 0, ranks.length - 1)
    }

    pairs.filter(pair => pair._1 != pair._2).toArray
  }
}

/** Adapter that turns a permutation tensor into a transpose layer followed by `Contiguous`. */
class TransposeLoadTF[T: ClassTag]()(implicit ev: TensorNumeric[T]) extends Adapter[T](Array(2)) {
  import TransposeLoadTF._

  /**
   * @param tensorArrays its first element holds the permutation, read from the tensor's
   *                     backing storage array
   * @return a `Sequential` containing the transpose layer and a `Contiguous` layer
   */
  override def build(tensorArrays: Array[Tensor[_]]): AbstractModule[Activity, Activity, T] = {
    val perm = tensorArrays(0).asInstanceOf[Tensor[Int]].storage().array()
    // permToPair yields 0-based swaps; both indices are shifted by one for the layer.
    val oneBasedSwaps = permToPair(perm).map { case (from, to) => (from + 1, to + 1) }
    val layer = Sequential()
    layer.add(TransposeLayer[T](oneBasedSwaps))
    layer.add(Contiguous())
    layer
  }
}
Example 29
Source File: Pad.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf.loaders

import java.nio.ByteOrder

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.{Padding, Sequential}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.tf.{Context, TFUtils}
import org.tensorflow.framework.NodeDef

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/** Loader that builds a [[PadLoadTF]]. */
class Pad extends TensorflowOpsLoader {

  import Utils._

  /** Returns a new [[PadLoadTF]]; none of the three arguments is read. */
  override def build[T: ClassTag](
    nodeDef: NodeDef,
    byteOrder: ByteOrder,
    context: Context[T])(implicit ev: TensorNumeric[T]): Module[T] = new PadLoadTF[T]()
}

/** Adapter that turns a paddings tensor into a chain of `Padding` layers. */
class PadLoadTF[T: ClassTag]()(implicit ev: TensorNumeric[T]) extends Adapter[T](Array(2)) {

  /**
   * Adds up to two `Padding` layers per row of the paddings tensor, one for each non-zero
   * column; rows whose two amounts are both zero contribute nothing.
   *
   * @param tensorArrays its first element is the paddings tensor, read as one row per
   *                     dimension with the two amounts in columns 1 and 2
   */
  override def build(tensorArrays: Array[Tensor[_]]): AbstractModule[Activity, Activity, T] = {
    val paddings = tensorArrays(0).asInstanceOf[Tensor[Int]]
    val padding = Sequential[T]()

    for (dim <- 1 to paddings.size(1)) {
      val first = paddings.valueAt(dim, 1)
      val second = paddings.valueAt(dim, 2)
      // NOTE(review): the first amount is negated, presumably because Padding treats a
      // negative pad as "pad in front"; the literal 4 is the layer's third argument
      // (presumably the expected input dimensionality) -- confirm both against Padding.
      if (first != 0) {
        padding.add(Padding[T](dim, -first, 4))
      }
      if (second != 0) {
        padding.add(Padding[T](dim, second, 4))
      }
    }

    padding
  }
}
Example 30
Source File: IRConverter.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.intermediate

import com.intel.analytics.bigdl.nn.Graph
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.mkldnn._
import com.intel.analytics.bigdl.tensor.{FloatType, Tensor}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.{Module, utils}
import com.intel.analytics.bigdl.utils.{Engine, MklBlas, MklDnn, Node}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag


/**
 * Converts an [[IRGraph]] into an executable [[Graph]] for the engine type currently reported
 * by `Engine.getEngineType()`.
 *
 * On construction, every node reachable from the graph inputs through `nextNodes` is
 * collected, plus any output node that was not reached that way.
 *
 * @param IRgraph the intermediate-representation graph to convert
 */
private[bigdl] class IRConverter[T: ClassTag](IRgraph: IRGraph[T])(implicit ev: TensorNumeric[T]) {
  // All nodes of the IR graph, in discovery order (filled by init()).
  private val allNodes = new ArrayBuffer[Node[IRElement[T]]]
  private val irInputs = IRgraph.inputs.toArray
  private val irOutputs = IRgraph.outputs.toArray

  init()
  // Populates allNodes from the inputs and then appends unreached output nodes.
  private def init() : Unit = {
    getNodes(irInputs, allNodes)
    // reminder: some output nodes may not be searched from inputs
    irOutputs.foreach(node => {
      if (!allNodes.contains(node)) allNodes.append(node)
    })
  }


  // Depth-first collection: appends each not-yet-seen node and then recurses into its
  // successors. Membership is checked with a linear scan of the buffer.
  private def getNodes(inputs: Seq[Node[IRElement[T]]],
                       nodesBuffer: ArrayBuffer[Node[IRElement[T]]]): Unit = {
    if (inputs.length == 0) return
    inputs.foreach(node => {
      if (!nodesBuffer.contains(node)) {
        nodesBuffer.append(node)
        getNodes(node.nextNodes, nodesBuffer)
      }
    })
  }

  
  /**
   * Builds the graph for the active engine type.
   *
   * Fails a `require` (IllegalArgumentException) when the nodes cannot be converted for that
   * engine or, for MklDnn, when the numeric type is not float; throws
   * UnsupportedOperationException for any engine type other than MklBlas and MklDnn.
   */
  def toGraph() : Graph[T] = {
    if (utils.Engine.getEngineType() == MklBlas) {
      require(IRToBlas[T].convertingCheck(allNodes.toArray),
        "IR graph can not be converted to Blas layer")
      toBlasGraph()
    } else if (utils.Engine.getEngineType() == MklDnn) {
      require(ev.getType() == FloatType, "Mkldnn engine only supports float data")
      require(IRToDnn[Float].convertingCheck(
        allNodes.toArray.asInstanceOf[Array[Node[IRElement[Float]]]]),
        "IR graph can not be converted to Dnn layer")
      toDnnGraph()
    } else throw new UnsupportedOperationException(
      s"Only support engineType mkldnn/mklblas, but get ${Engine.getEngineType()}")
  }

  // Converts the nodes with IRToDnn and wraps the result into a DnnGraph. The casts to Float
  // rely on the FloatType check performed in toGraph().
  private def toDnnGraph(): Graph[T] = {
    val nodeMap = IRToDnn[Float].convert(
      allNodes.toArray.asInstanceOf[Array[Node[IRElement[Float]]]])
    val inputs = irInputs.map(
      n => nodeMap.get(n.asInstanceOf[Node[IRElement[Float]]]).get)
    val outputs = irOutputs.map(
      n => nodeMap.get(n.asInstanceOf[Node[IRElement[Float]]]).get)

    // add input node for dnn graph
    val realInputs = inputs.map(n => {
      val node = new Node[Module[Float]](new InputWrapper())
      n.from(node)
      node
    })

    // add output node for graph
    // Outputs that are already BlasWrapper elements are used directly; every other output
    // gets an Output node with the format recorded for its index.
    val realOutputs = outputs.zipWithIndex.map {
      case (model: Node[Module[Float]], index: Int) =>
        val node = if (model.element.isInstanceOf[BlasWrapper]) {
          model
        } else {
          model.add(new Node[Module[Float]](Output(IRgraph.outputFormats(index))))
        }
        node
    }

    DnnGraph(realInputs, realOutputs,
      IRgraph.variables.asInstanceOf[Option[(Array[Tensor[Float]], Array[Tensor[Float]])]],
      IRgraph.generateBackward).asInstanceOf[Graph[T]]
  }

  // Converts the nodes with IRToBlas and builds a dynamic graph from the mapped inputs and
  // outputs.
  private def toBlasGraph(): Graph[T] = {
    val nodeMap = IRToBlas[T].convert(allNodes.toArray)
    val inputs = irInputs.map(n => nodeMap.get(n).get)
    val outputs = irOutputs.map(n => nodeMap.get(n).get)

    Graph.dynamic(inputs, outputs, IRgraph.variables, IRgraph.generateBackward)
  }
}
Example 31
Source File: FileReader.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.visualization.tensorboard

import java.io.{BufferedInputStream}
import java.nio.ByteBuffer

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.tensorflow.util.Event

import scala.collection.mutable.ArrayBuffer
import scala.util.matching.Regex

private[bigdl] object FileReader {
  /** Regular expression for event file names. */
  val fileNameRegex = """bigdl.tfevents.*""".r

  /**
   * Reads from `in` until `buffer` is full or the end of the stream is reached.
   *
   * A single `read` call may deliver fewer bytes than requested, so the call is repeated.
   *
   * @param in     stream to read from
   * @param buffer destination, filled from index 0
   * @return number of bytes stored in `buffer`; less than `buffer.length` only at end of stream
   */
  private def readFully(in: BufferedInputStream, buffer: Array[Byte]): Int = {
    var filled = 0
    var endOfStream = false
    while (!endOfStream && filled < buffer.length) {
      val count = in.read(buffer, filled, buffer.length - filled)
      if (count < 0) endOfStream = true else filled += count
    }
    filled
  }

  /**
   * Reads every scalar summary recorded under `tag` from an event file.
   *
   * Each record is read as: an 8-byte little-endian payload length, 4 checksum bytes, the
   * payload (a serialized `Event`), and 4 more checksum bytes. The checksums are skipped, not
   * verified. Only events that hold exactly one summary value, tagged `tag`, are returned.
   * Reading stops at the end of the file or at an incomplete length header.
   *
   * @param file path of the event file; must be a regular file
   * @param tag  tag of the scalar to extract
   * @param fs   file system used to open `file`
   * @return (step, value, wall time) triples sorted by ascending step
   */
  def readScalar(file: Path, tag: String, fs: FileSystem): Array[(Long, Float, Double)] = {
    require(fs.isFile(file), s"FileReader: ${file} should be a file")
    val bis = new BufferedInputStream(fs.open(file))
    // Close the stream even when a record cannot be parsed.
    try {
      val longBuffer = new Array[Byte](8)
      val crcBuffer = new Array[Byte](4)
      val bf = new ArrayBuffer[(Long, Float, Double)]
      while (readFully(bis, longBuffer) == longBuffer.length) {
        // The length is stored little-endian; ByteBuffer decodes big-endian, hence the reverse.
        val l = ByteBuffer.wrap(longBuffer.reverse).getLong()
        readFully(bis, crcBuffer)
        // TODO: checksum
        //      val crc1 = ByteBuffer.wrap(crcBuffer.reverse).getInt()
        val eventBuffer = new Array[Byte](l.toInt)
        readFully(bis, eventBuffer)
        val e = Event.parseFrom(eventBuffer)
        if (e.getSummary.getValueCount == 1 &&
          tag.equals(e.getSummary.getValue(0).getTag())) {
          bf.append((e.getStep, e.getSummary.getValue(0).getSimpleValue,
            e.getWallTime))
        }
        readFully(bis, crcBuffer)
        //      val crc2 = ByteBuffer.wrap(crcBuffer.reverse).getInt()
      }
      bf.toArray.sortWith(_._1 < _._1)
    } finally {
      bis.close()
    }
  }
}
Example 32
Source File: Permute.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.keras

import com.intel.analytics.bigdl.nn.Transpose
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Shape

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag


/**
 * Keras-style layer that reorders the dimensions of its input.
 *
 * @param dims       for each output dimension, the input dimension it is taken from; the values
 *                   are shifted by -1 and then by +2 before being handed to `Transpose`
 * @param inputShape shape of the input, passed through `KerasLayer.addBatch`
 */
class Permute[T: ClassTag](
   val dims: Array[Int],
   val inputShape: Shape = null)(implicit ev: TensorNumeric[T])
  extends KerasLayer[Tensor[T], Tensor[T], T](KerasLayer.addBatch(inputShape)) {

  /**
   * Expresses a permutation as a sequence of pairwise swaps.
   *
   * The position of every number inside `perm` is looked up, that position array is quick-sorted
   * in place, and every exchange made by the sort is recorded in order.
   *
   * @param perm permutation of `0 until perm.length`
   * @return the recorded exchanges, without those that swap an index with itself
   */
  private def permToPair(perm: Array[Int]): Array[(Int, Int)] = {
    val positionOf = perm.zipWithIndex.toMap
    val recorded = ArrayBuffer[(Int, Int)]()

    // Swaps two slots and remembers the exchange.
    def exchange(values: Array[Int], a: Int, b: Int): Unit = {
      val held = values(a)
      values(a) = values(b)
      values(b) = held
      recorded += ((a, b))
    }

    // In-place quicksort with the middle element as pivot.
    def quickSort(values: Array[Int], low: Int, high: Int): Unit = {
      val pivot = values(low + (high - low) / 2)
      var left = low
      var right = high

      while (left <= right) {
        while (values(left) < pivot) left += 1
        while (values(right) > pivot) right -= 1

        if (left <= right) {
          exchange(values, left, right)
          left += 1
          right -= 1
        }
      }

      if (low < right) quickSort(values, low, right)
      if (left < high) quickSort(values, left, high)
    }

    val positions = perm.indices.toArray.map(positionOf)
    quickSort(positions, 0, positions.length - 1)

    recorded.filter { case (a, b) => a != b }.toArray
  }

  /**
   * Output shape: slot `i + 1` receives the input size found at index `dims(i)`; slots not
   * addressed by `dims` (slot 0 in particular) keep the input's size.
   */
  override def computeOutputShape(inputShape: Shape): Shape = {
    val inputSizes = inputShape.toSingle().toArray
    val outputSizes = inputSizes.clone()
    for (i <- dims.indices) {
      outputSizes(i + 1) = inputSizes(dims(i))
    }
    Shape(outputSizes)
  }

  /** Builds the `Transpose` layer that performs the swaps derived from `dims`. */
  override def doBuild(inputShape: Shape): AbstractModule[Tensor[T], Tensor[T], T] = {
    val zeroBased = dims.map(_ - 1)
    val swaps = permToPair(zeroBased).map { case (from, to) => (from + 2, to + 2) }
    Transpose(swaps).asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
  }
}

object Permute {
  /** Factory equivalent to `new Permute[T](dims, inputShape)`. */
  def apply[@specialized(Float, Double) T: ClassTag](
    dims: Array[Int],
    inputShape: Shape = null)(implicit ev: TensorNumeric[T]): Permute[T] =
    new Permute[T](dims, inputShape)
}
Example 33
Source File: FrameManager.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn

import java.util.concurrent.atomic.AtomicInteger

import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.tf.{Exit, MergeOps, NextIteration}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer


  /**
   * Bookkeeping for one frame.
   *
   * @param name   name of the frame
   * @param parent parent frame, when there is one
   */
  class Frame[T] private[FrameManager] (
    val name: String,
    val parent: Option[Frame[T]]
  ) {
    // Counter used to synchronise the execution of all next-iteration nodes; starts at 0
    private[bigdl] var barrier: AtomicInteger = new AtomicInteger()

    // Nodes a user registered (through NextIteration) to synchronise execution
    private[bigdl] val waitingNodes: ArrayBuffer[ModuleNode[T]] = ArrayBuffer.empty[ModuleNode[T]]

    // Nodes that should be refreshed in an iteration of the frame
    private[bigdl] val nodes: ArrayBuffer[ModuleNode[T]] = ArrayBuffer.empty[ModuleNode[T]]
  }
} 
Example 34
Source File: TimeDistributedCriterion.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.TensorCriterion
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Engine

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future
import scala.reflect.ClassTag


    require(input.size(dimension) == target.size(dimension),
      s"target should have as many elements as input, " +
        s"input ${input.size(dimension)}, target ${target.size(dimension)}")
    gradInput.resizeAs(input).zero()

    val nstep = input.size(dimension)

    var i = 0
    while (i < nstep) {
      val _i = i + 1
      results(i) = Engine.model.invoke(() => {
        fInput = input.select(dimension, _i)
        fTarget = target.select(dimension, _i)
        _gradInput = gradInput.select(dimension, _i)
        _gradInput.copy(cells(_i - 1).updateGradInput(fInput, fTarget).toTensor[T])
        if (sizeAverage) {
          _gradInput = _gradInput.div(ev.fromType[Int](nstep))
        }
      })
      i += 1
    }
    Engine.model.sync(results)
    gradInput
  }

  /**
   * Tells whether `other` is a `TimeDistributedCriterion`; the type argument is not checked.
   */
  override def canEqual(other: Any): Boolean = other match {
    case _: TimeDistributedCriterion[_] => true
    case _ => false
  }
}

object TimeDistributedCriterion {
  /**
   * Factory that forwards its arguments, in order, to the `TimeDistributedCriterion`
   * constructor.
   */
  def apply[@specialized(Float, Double) T: ClassTag](
    critrn: TensorCriterion[T] = null, sizeAverage: Boolean = false, dimension: Int = 2)
    (implicit ev: TensorNumeric[T]) : TimeDistributedCriterion[T] =
    new TimeDistributedCriterion[T](critrn, sizeAverage, dimension)
}
Example 35
Source File: ExpandSize.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag


/**
 * Expands singleton dimensions of the input to the requested sizes without copying data.
 *
 * @param targetSizes requested size per dimension; `-1` leaves that dimension untouched. A
 *                    dimension can only be expanded when its input size is 1.
 */
class ExpandSize[T: ClassTag](targetSizes: Array[Int])
   (implicit ev: TensorNumeric[T]) extends AbstractModule[Tensor[T], Tensor[T], T] {

  /**
   * Points `output` at the input's storage, with every expanded dimension given the target
   * size and a stride of 0.
   *
   * @throws UnsupportedOperationException if a dimension of size other than 1 is asked to
   *                                       change size
   */
  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    require(targetSizes.length == input.dim(),
      s"the number of dimensions provided must equal ${input.dim()}")
    val expandedSize = input.size()
    val expandedStride = input.stride()

    for (d <- 0 until input.dim()) {
      val target = targetSizes(d)
      if (target != -1) {
        if (expandedSize(d) == 1) {
          // Stride 0 makes every index along this dimension read the same element.
          expandedSize(d) = target
          expandedStride(d) = 0
        } else if (expandedSize(d) != target) {
          throw new UnsupportedOperationException(
            "incorrect size: only supporting singleton expansion (size=1)")
        }
      }
    }

    output.set(input.storage(), input.storageOffset(), expandedSize, expandedStride)
    output
  }

  /**
   * Sums `gradOutput` over every dimension that was expanded in the forward pass, so the
   * result has the input's shape. When no dimension was expanded, `gradOutput` itself is
   * returned.
   */
  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    val inputSize = input.size()

    // 1-based indices of the dimensions that were expanded from size 1.
    val expandedDims = (0 until input.dim()).collect {
      case d if targetSizes(d) != -1 && inputSize(d) == 1 && targetSizes(d) != 1 => d + 1
    }

    val reducedSize = gradOutput.size()
    var reduced = gradOutput
    // Collapse the expanded dimensions one at a time, from the last one to the first.
    for (dim <- expandedDims.reverse) {
      val slices = gradOutput.size(dim)
      reducedSize(dim - 1) = 1
      val summed = Tensor[T](reducedSize)
      var slice = 1
      while (slice <= slices) {
        summed.add(reduced.narrow(dim, slice, 1))
        slice += 1
      }
      reduced = summed
    }
    gradInput = reduced
    gradInput
  }

  override def toString: String = "ExpandSize"
}

object ExpandSize {
  /** Factory equivalent to `new ExpandSize[T](targetSizes)`. */
  def apply[@specialized(Float, Double) T: ClassTag](targetSizes: Array[Int])
     (implicit ev: TensorNumeric[T]) : ExpandSize[T] =
    new ExpandSize[T](targetSizes)
}
Example 36
Source File: Utils.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.quantized

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, TensorModule}
import com.intel.analytics.bigdl.nn.tf.WithoutInput
import com.intel.analytics.bigdl.nn.{Cell, Container, Graph, Input, TimeDistributed, Linear => NNLinear, SpatialConvolution => NNConv, SpatialDilatedConvolution => NNDilatedConv}
import com.intel.analytics.bigdl.tensor.{QuantizedTensor, Tensor}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Node
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object Utils {
  type ModuleNode[R] = AbstractModule[Activity, Activity, R]
  type SeqNodes[R] = Seq[Node[ModuleNode[R]]]
  type ArrayNodes[R] = Array[Node[ModuleNode[R]]]
  type ANode[R] = Node[ModuleNode[R]]
  type AbsModule[R] = AbstractModule[Activity, Activity, R]

  /**
   * Packs every parameter that is not a `QuantizedTensor` into one contiguous tensor.
   *
   * The elements of each such parameter are copied, in order, into a newly allocated 1-D
   * tensor; the parameter is then re-pointed at its slice of that tensor, keeping its size and
   * stride. Quantized parameters are left untouched.
   *
   * @param parameters parameters to pack; the non-quantized ones are modified in place
   * @return the tensor that now backs all non-quantized parameters
   */
  def reorganizeParameters[T: ClassTag](parameters: Array[Tensor[T]])(
    implicit ev: TensorNumeric[T]): Tensor[T] = {
    val plainParams = parameters.filterNot(_.isInstanceOf[QuantizedTensor[T]])
    val result = Tensor[T](plainParams.map(_.nElement()).sum)

    var offset = 0
    plainParams.foreach { parameter =>
      val count = parameter.nElement()

      // storageOffset() is 1-based, array indices are 0-based.
      System.arraycopy(
        parameter.storage().array(), parameter.storageOffset() - 1,
        result.storage().array(), offset,
        count)
      // Size and stride are read before `set` replaces the parameter's storage.
      parameter.set(result.storage(), offset + 1, parameter.size(), parameter.stride())

      offset += count
    }

    result
  }
}
Example 37
Source File: Any.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/**
 * Logical OR reduction of a boolean tensor along the given dimensions.
 *
 * The input table holds the boolean data at key 1 and the dimensions to reduce (an `Int`
 * scalar or 1-D tensor) at key 2.
 *
 * @param keepDim       when false, the reduced dimensions are removed from the output shape
 * @param startFromZero when true, non-negative dimension indices are treated as 0-based
 */
class Any[T: ClassTag](keepDim : Boolean = false, startFromZero : Boolean = false)
  (implicit ev: TensorNumeric[T]) extends Operation[Table,
  Tensor[Boolean], T] {

  output = Tensor[Boolean]()

  // Holds the partially reduced data between two successive reductions
  private var buffer = Tensor[Boolean]()

  override def updateOutput(input: Table): Tensor[Boolean] = {
    val data = input[Tensor[Boolean]](1)
    val indices = input[Tensor[Int]](2)
    require(indices.nDimension() == 1 || indices.isScalar, "indices must be 1D tensor or scala")
    output.resizeAs(data)
    buffer.resizeAs(data).copy(data)
    val reduceDims = new ArrayBuffer[Int]()
    val size = output.size()

    // Turns a raw index into a 1-based dimension; negative indices count from the last one.
    def toDimension(index: Int): Int = {
      if (index < 0) data.nDimension() + index + 1
      else if (startFromZero) index + 1
      else index
    }

    // Reduces `dim` with OR unless it already has size 1, then refreshes the buffer.
    def reduceAlong(dim: Int): Unit = {
      if (size(dim - 1) != 1) {
        size(dim - 1) = 1
        reduceDims += dim
        output.resize(size)
        buffer.reduce(dim, output, (a, b) => a || b)
        buffer.resizeAs(output).copy(output)
      }
    }

    if (indices.isScalar) {
      reduceAlong(toDimension(indices.value()))
    } else {
      for (position <- 1 to indices.size(1)) {
        reduceAlong(toDimension(indices.valueAt(position)))
      }
    }

    if (!keepDim) {
      // Keep only the sizes of the dimensions that were not reduced.
      val keptSizes: Array[Int] = (1 to data.nDimension())
        .filterNot(dim => reduceDims.contains(dim))
        .map(dim => data.size(dim))
        .toArray
      output.resize(keptSizes)
    }
    output
  }

  /** Clears the inherited state and the internal buffer. */
  override def clearState(): this.type = {
    super.clearState()
    buffer.set()
    this
  }
}

object Any {
  /** Factory equivalent to `new Any[T](keepDim, startFromZero)`. */
  def apply[T: ClassTag](keepDim: Boolean = false, startFromZero : Boolean = false)
    (implicit ev: TensorNumeric[T]): Any[T] = {
    new Any[T](keepDim = keepDim, startFromZero = startFromZero)
  }
}
Example 38
Source File: CategoricalColVocaList.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.HashFunc

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag



/**
 * Maps delimiter-separated string features to vocabulary indices, returned as a sparse tensor.
 *
 * @param vocaList      vocabulary; must be non-empty and free of duplicates
 * @param strDelimiter  separator passed to `String.split` for each input entry
 * @param isSetDefault  when true, features missing from the vocabulary map to the index
 *                      `vocaList.length`
 * @param numOovBuckets number of hash buckets for features missing from the vocabulary; must
 *                      not be negative and must be 0 when `isSetDefault` is true
 */
class CategoricalColVocaList[T: ClassTag](
  val vocaList: Array[String],
  val strDelimiter: String = ",",
  val isSetDefault: Boolean = false,
  val numOovBuckets: Int = 0
) (implicit ev: TensorNumeric[T])
  extends Operation[Tensor[String], Tensor[Int], T]{

  private val vocaLen = vocaList.length
  private val vocaMap = vocaList.zipWithIndex.toMap

  require(numOovBuckets >= 0,
    "numOovBuckets is a negative integer")
  require(!(isSetDefault && numOovBuckets != 0),
    "defaultValue and numOovBuckets are both specified")
  require(vocaLen > 0,
    "the vocabulary list is empty")
  require(vocaLen == vocaMap.size,
    "the vocabulary list contains duplicate keys")

  output = Tensor[Int]()

  /**
   * Builds a sparse tensor with one row per input entry; the feature at position `j` of row
   * `i` (after filtering) is stored at `(i, j)`. The input is squeezed in place first.
   */
  override def updateOutput(input: Tensor[String]): Tensor[Int] = {

    input.squeeze()
    val rows = input.size(dim = 1)

    val cols =
      if (numOovBuckets != 0) vocaLen + numOovBuckets
      else if (isSetDefault) vocaLen + 1
      else vocaLen
    val shape = Array(rows, cols)
    val indices0 = new ArrayBuffer[Int]()
    val indices1 = new ArrayBuffer[Int]()
    val values = new ArrayBuffer[Int]()

    // Without a default index or hash buckets, unknown features are dropped.
    val dropUnknown = !isSetDefault && numOovBuckets == 0

    for (row <- 1 to rows) {
      val tokens = input.valueAt(row).split(strDelimiter)
      val features = if (dropUnknown) tokens.filter(vocaMap.contains) else tokens
      for (col <- features.indices) {
        val feature = features(col)
        val mapVal =
          if (numOovBuckets == 0) {
            vocaMap.getOrElse(feature, vocaMap.size)
          } else {
            vocaMap.getOrElse(feature,
              HashFunc.stringHashBucket32(feature, numOovBuckets) + vocaLen)
          }
        indices0 += row - 1
        indices1 += col
        values += mapVal
      }
    }
    val indices = Array(indices0.toArray, indices1.toArray)
    output = Tensor.sparse(indices, values.toArray, shape)
    output
  }
}

object CategoricalColVocaList {
  /** Factory that forwards all arguments to the `CategoricalColVocaList` constructor. */
  def apply[T: ClassTag](
    vocaList: Array[String],
    strDelimiter: String = ",",
    isSetDefault: Boolean = false,
    numOovBuckets: Int = 0
  ) (implicit ev: TensorNumeric[T]): CategoricalColVocaList[T] = {
    new CategoricalColVocaList[T](vocaList, strDelimiter, isSetDefault, numOovBuckets)
  }
}
Example 39
Source File: CategoricalColHashBucket.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.util.hashing.MurmurHash3



/**
 * Hashes delimiter-separated string features into buckets.
 *
 * @param hashBucketSize number of buckets the MurmurHash3 string hash is reduced to
 * @param strDelimiter   separator passed to `String.split` for each input entry
 * @param isSparse       when true the result is the sparse tensor itself, otherwise its dense
 *                       conversion
 */
class CategoricalColHashBucket[T: ClassTag](
  val hashBucketSize: Int,
  val strDelimiter: String = ",",
  val isSparse: Boolean = true
  )(implicit ev: TensorNumeric[T])
  extends Operation[Tensor[String], Tensor[Int], T] {

  output = Tensor[Int]()

  /**
   * Produces a tensor of shape (rows, longest feature list); the bucket of feature `j` in row
   * `i` is stored at `(i, j)`.
   */
  override def updateOutput(input: Tensor[String]): Tensor[Int] = {
    val rows = input.size(dim = 1)
    val indices0 = new ArrayBuffer[Int]()
    val indices1 = new ArrayBuffer[Int]()
    val values = new ArrayBuffer[Int]()
    var maxFeaLen = 0

    for (row <- 1 to rows) {
      val features = input.valueAt(row, 1).split(strDelimiter)
      maxFeaLen = math.max(maxFeaLen, features.length)
      for (col <- features.indices) {
        // `%` keeps the sign of the hash, so shift negative remainders by the bucket count.
        val remainder = MurmurHash3.stringHash(features(col)) % hashBucketSize
        val hashVal = if (remainder < 0) remainder + hashBucketSize else remainder
        indices0 += row - 1
        indices1 += col
        values += hashVal
      }
    }

    val indices = Array(indices0.toArray, indices1.toArray)
    val shape = Array(rows, maxFeaLen)
    val sparse: Tensor[Int] = Tensor.sparse(indices, values.toArray, shape)
    output = if (isSparse) sparse else Tensor.dense(sparse)
    output
  }
}

object CategoricalColHashBucket{
  /** Factory that forwards all arguments to the `CategoricalColHashBucket` constructor. */
  def apply[T: ClassTag](
      hashBucketSize: Int,
      strDelimiter: String = ",",
      isSparse: Boolean = true)
      (implicit ev: TensorNumeric[T])
  : CategoricalColHashBucket[T] = {
    new CategoricalColHashBucket[T](hashBucketSize, strDelimiter, isSparse)
  }
}
Example 40
Source File: Sum.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table
import com.intel.analytics.bigdl.nn.{Sum => SumLayer}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/**
 * Sums a tensor over the given dimensions.
 *
 * The input table holds the data at key 1 and the dimensions (an `Int` scalar or 1-D tensor)
 * at key 2. An empty dimension tensor yields a plain copy of the data.
 *
 * @param keepDims      when false, the wrapped sum layer squeezes each summed dimension
 * @param startFromZero when true, the given dimensions are shifted by one before use
 */
class Sum[T: ClassTag, D: ClassTag](val keepDims: Boolean, val startFromZero: Boolean = false)
  (implicit ev: TensorNumeric[T], ev2: TensorNumeric[D])
  extends Operation[Table, Tensor[D], T] {

  private val sum: SumLayer[D] = SumLayer[D](squeeze = !keepDims)

  output = Tensor[D]()

  override def updateOutput(input: Table): Tensor[D] = {
    val data = input[Tensor[D]](1)
    val dims = input[Tensor[Int]](2)

    output.resizeAs(data).copy(data)

    if (!dims.isEmpty) {
      def shifted(dim: Int): Int = if (startFromZero) dim + 1 else dim

      val sumDims: Array[Int] =
        if (dims.isScalar) {
          Array(shifted(dims.value()))
        } else {
          require(dims.nDimension() == 1,
            s"Only accept 1D as dims, but now is ${dims.nDimension()}")
          val collected = new ArrayBuffer[Int]()
          // apply1 is used only to visit every element; each value is handed back unchanged.
          dims.apply1 { dim =>
            collected.append(shifted(dim))
            dim
          }
          // Descending order: the highest dimension is summed first.
          collected.toArray.sortWith(_ > _)
        }

      sumDims.foreach { dim =>
        sum.changeSumDims(dim)
        val partial = sum.updateOutput(output)
        output.resizeAs(partial).copy(partial)
      }
    }

    output
  }

  /** Exposes the class tags and numerics of both type parameters, `T` first. */
  override def getClassTagNumerics() : (Array[ClassTag[_]], Array[TensorNumeric[_]]) = {
    val tags = Array[ClassTag[_]](scala.reflect.classTag[T], scala.reflect.classTag[D])
    val numerics = Array[TensorNumeric[_]](ev, ev2)
    (tags, numerics)
  }
}

object Sum {
  /** Factory equivalent to `new Sum[T, D](keepDims, startFromZero)`. */
  def apply[T: ClassTag, D: ClassTag](keepDims: Boolean = false, startFromZero: Boolean = false)
    (implicit ev: TensorNumeric[T], ev2: TensorNumeric[D]): Sum[T, D] = {
    new Sum[T, D](keepDims, startFromZero)
  }
}
Example 41
Source File: Kv2Tensor.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag



/**
 * Parses strings of `key<itemDelimiter>value` items into a tensor.
 *
 * The input table holds the string tensor at key 1 and the feature length (an `Int` scalar)
 * at key 2. The key of each item is used as the column index and its value as the element.
 *
 * @param kvDelimiter   separator between the items of one row
 * @param itemDelimiter separator between the key and the value of an item
 * @param transType     0 for a dense result, 1 for a sparse one
 */
class Kv2Tensor[T: ClassTag, D: ClassTag](
  val kvDelimiter: String,
  val itemDelimiter: String,
  val transType: Int
  )(implicit ev: TensorNumeric[T], ev2: TensorNumeric[D])
  extends Operation[Table, Tensor[D], T]{

  output = Activity.allocate[Tensor[D], D]()

  /**
   * Builds a tensor of shape (rows, feature length) from the key/value strings.
   *
   * @throws NotImplementedError if `D` is neither `Double` nor `Float`
   */
  override def updateOutput(input: Table): Tensor[D] = {
    val kvTensor = input[Tensor[String]](1)
    val feaLen = input[Tensor[Int]](2).value()
    val indices0 = new ArrayBuffer[Int]()
    val indices1 = new ArrayBuffer[Int]()
    val values = new ArrayBuffer[D]()
    val rows = kvTensor.size(dim = 1)
    val shape = Array(rows, feaLen)

    for (row <- 1 to rows) {
      val kvFeaString = kvTensor.select(1, row).valueAt(1)
      kvFeaString.split(kvDelimiter).foreach { kv =>
        val item = kv.split(itemDelimiter)
        indices0 += row - 1
        indices1 += item(0).toInt
        ev2.getType() match {
          case DoubleType =>
            values += item(1).toDouble.asInstanceOf[D]
          case FloatType =>
            values += item(1).toFloat.asInstanceOf[D]
          case t => throw new NotImplementedError(s"$t is not supported")
        }
      }
    }

    val indices = Array(indices0.toArray, indices1.toArray)
    val converted = transType match {
      case 0 =>
        Tensor.dense(Tensor.sparse(indices, values.toArray, shape))
      case 1 =>
        Tensor.sparse(indices, values.toArray, shape)
    }
    output = converted
    output
  }

  /** Exposes the class tags and numerics of both type parameters, `T` first. */
  override def getClassTagNumerics() : (Array[ClassTag[_]], Array[TensorNumeric[_]]) = {
    val tags = Array[ClassTag[_]](scala.reflect.classTag[T], scala.reflect.classTag[D])
    val numerics = Array[TensorNumeric[_]](ev, ev2)
    (tags, numerics)
  }
}

object Kv2Tensor{
  /** Factory that forwards all arguments to the `Kv2Tensor` constructor. */
  def apply[T: ClassTag, D: ClassTag](
     kvDelimiter: String = ",",
     itemDelimiter: String = ":",
     transType: Int = 0)
     (implicit ev: TensorNumeric[T], ev2: TensorNumeric[D]): Kv2Tensor[T, D] = {
    new Kv2Tensor[T, D](kvDelimiter, itemDelimiter, transType)
  }
}
Example 42
Source File: All.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/**
 * Logical AND reduction of a boolean tensor along the given dimensions.
 *
 * The input table holds the boolean data at key 1 and the dimensions to reduce (an `Int`
 * scalar or 1-D tensor) at key 2.
 *
 * @param keepDim       when false, the reduced dimensions are removed from the output shape
 * @param startFromZero when true, non-negative dimension indices are treated as 0-based
 */
class All[T: ClassTag](keepDim : Boolean = false, startFromZero : Boolean = false)
  (implicit ev: TensorNumeric[T]) extends Operation[Table,
  Tensor[Boolean], T] {

  output = Tensor[Boolean]()

  // Holds the partially reduced data between two successive reductions
  private var buffer = Tensor[Boolean]()

  override def updateOutput(input: Table): Tensor[Boolean] = {
    val data = input[Tensor[Boolean]](1)
    val indices = input[Tensor[Int]](2)
    require(indices.nDimension() == 1 || indices.isScalar, "indices must be 1D tensor or scala")
    output.resizeAs(data)
    buffer.resizeAs(data).copy(data)
    val reduceDims = new ArrayBuffer[Int]()
    val size = output.size()

    // Turns a raw index into a 1-based dimension; negative indices count from the last one.
    def toDimension(index: Int): Int = {
      if (index < 0) data.nDimension() + index + 1
      else if (startFromZero) index + 1
      else index
    }

    // Reduces `dim` with AND unless it already has size 1, then refreshes the buffer.
    def reduceAlong(dim: Int): Unit = {
      if (size(dim - 1) != 1) {
        size(dim - 1) = 1
        reduceDims += dim
        output.resize(size)
        buffer.reduce(dim, output, (a, b) => a && b)
        buffer.resizeAs(output).copy(output)
      }
    }

    if (indices.isScalar) {
      reduceAlong(toDimension(indices.value()))
    } else {
      for (position <- 1 to indices.size(1)) {
        reduceAlong(toDimension(indices.valueAt(position)))
      }
    }

    if (!keepDim) {
      // Keep only the sizes of the dimensions that were not reduced.
      val keptSizes: Array[Int] = (1 to data.nDimension())
        .filterNot(dim => reduceDims.contains(dim))
        .map(dim => data.size(dim))
        .toArray
      output.resize(keptSizes)
    }
    output
  }

  /** Clears the inherited state and the internal buffer. */
  override def clearState(): this.type = {
    super.clearState()
    buffer.set()
    this
  }
}

object All {
  /** Factory equivalent to `new All[T](keepDim, startFromZero)`. */
  def apply[T: ClassTag](keepDim: Boolean = false, startFromZero : Boolean = false)
    (implicit ev: TensorNumeric[T]): All[T] = {
    new All[T](keepDim = keepDim, startFromZero = startFromZero)
  }
}
Example 43
Source File: ParallelTable.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag



/**
 * Container that applies its i-th module to the i-th entry of the input table and stores the
 * result as the i-th entry of the output table.
 */
@SerialVersionUID(- 1197848941394786045L)
class ParallelTable[T: ClassTag]
  (implicit ev: TensorNumeric[T]) extends DynamicContainer[Table, Table, T] {

  /** Runs `forward` of module `k` on input entry `k`, for every entry of `input`. */
  override def updateOutput(input: Table): Table = {
    for (k <- 1 to input.length()) {
      output.update(k, modules(k - 1).forward(input(k)))
    }
    output
  }

  /** Runs `updateGradInput` of module `k` on entry `k` of `input` and `gradOutput`. */
  override def updateGradInput(input: Table, gradOutput: Table): Table = {
    for (k <- 1 to input.length()) {
      gradInput.update(k, modules(k - 1).updateGradInput(input(k), gradOutput(k)))
    }
    gradInput
  }

  /** Runs `accGradParameters` of module `k` on entry `k` of `input` and `gradOutput`. */
  override def accGradParameters(input: Table, gradOutput: Table): Unit = {
    for (k <- 1 to input.length()) {
      modules(k - 1).accGradParameters(input(k), gradOutput(k))
    }
  }

  /**
   * Runs `backward` of module `k` on entry `k` of `input` and `gradOutput`, and adds the
   * elapsed nanoseconds to `backwardTime`.
   */
  override def backward(input: Table, gradOutput: Table): Table = {
    val startedAt = System.nanoTime()
    for (k <- 1 to input.length()) {
      gradInput.update(k, modules(k - 1).backward(input(k), gradOutput(k)))
    }
    backwardTime += System.nanoTime() - startedAt
    gradInput
  }

  /**
   * Collects, in module order, the end nodes each module reports for its own start node.
   *
   * @param startNodes one start node per module
   */
  override def getEndNodes(startNodes: Array[ModuleNode[T]]): Array[ModuleNode[T]] = {
    require(startNodes.length == modules.length, s"ParallelTable: " +
      s"startNodes length ${startNodes.length} is more than modules length ${modules.length}")
    val endNodes = ArrayBuffer[ModuleNode[T]]()
    modules.zip(startNodes).foreach { case (module, startNode) =>
      endNodes ++= module.getEndNodes(Array(startNode))
    }
    endNodes.toArray
  }

  /** Renders the container as a tree with one numbered branch per module. */
  override def toString: String = {
    val tab = "\t"
    val line = "\n"
    val next = "  |`-> "
    val lastNext = "   `-> "
    val ext = "  |    "
    val extlast = "       "
    val last = "   ... -> "

    val text = new StringBuilder("nn.ParallelTable")
    text.append(" {").append(line).append(tab).append("input")
    modules.zipWithIndex.foreach { case (module, index) =>
      // The last branch uses a different arrow and continuation prefix.
      val isLast = index == modules.length - 1
      val arrow = if (isLast) lastNext else next
      val continuation = if (isLast) extlast else ext
      text.append(line).append(tab).append(arrow)
        .append("(").append(index + 1).append("): ")
        .append(module.toString.replace(line, line + tab + continuation))
    }
    text.append(line).append(tab).append(last).append("output")
    text.append(line).append("}")
    text.toString
  }
}

object ParallelTable {
  /** Factory equivalent to `new ParallelTable[T]()`. */
  def apply[@specialized(Float, Double) T: ClassTag]()
      (implicit ev: TensorNumeric[T]) : ParallelTable[T] =
    new ParallelTable[T]()
}
Example 44
Source File: MultiCriterion.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.{Activity, AbstractCriterion}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.T

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag



/**
 * Criterion whose loss is the weighted sum of the losses of several criterions, all evaluated
 * on the same input and target.
 */
@SerialVersionUID(- 8679064077837483164L)
class MultiCriterion[@specialized(Float, Double) T: ClassTag]
(implicit ev: TensorNumeric[T]) extends AbstractCriterion[Activity, Activity, T] {

  // weights(i - 1) belongs to the criterion stored at key i of `criterions`
  private val weights = new ArrayBuffer[Double]
  private val criterions = T()

  /**
   * Appends a criterion.
   *
   * @param criterion criterion to add
   * @param weight    factor applied to its loss and gradient
   */
  def add(criterion: AbstractCriterion[Activity, Activity, T], weight: Double = 1): Unit = {
    criterions.insert(criterions.length() + 1, criterion)
    weights.append(weight)
  }

  /**
   * Computes the weighted sum of the losses of all added criterions.
   *
   * @return the combined loss of this call only
   */
  override def updateOutput(input: Activity, target: Activity): T = {
    // Start from zero on every call; otherwise the loss of earlier calls is added to this one.
    output = ev.fromType[Int](0)
    var i = 1
    while (i <= criterions.length) {
      output = ev.plus(output, ev.times(ev.fromType(weights(i-1)),
        criterions[AbstractCriterion[Activity, Activity, T]](i).updateOutput(input, target)))
      i +=1
    }
    output
  }

  /**
   * Computes the weighted sum of the gradients of all added criterions; `gradInput` is resized
   * like `input` and zeroed before accumulation.
   */
  override def updateGradInput(input: Activity, target: Activity): Activity = {
    gradInput = Utils.recursiveResizeAs[T](gradInput,
      input)
    Utils.recursiveFill[T](gradInput, 0)
    var i = 1
    while (i <= criterions.length) {
      Utils.recursiveAdd(gradInput, weights(i - 1),
        criterions[AbstractCriterion[Activity, Activity, T]](i).updateGradInput(input, target))
      i += 1
    }
    gradInput
  }

  override def canEqual(other: Any): Boolean = other.isInstanceOf[MultiCriterion[T]]

  /** Equal when the inherited comparison holds and the weights match. */
  override def equals(other: Any): Boolean = other match {
    case that: MultiCriterion[T] =>
      super.equals(that) &&
      (that canEqual this) &&
        weights == that.weights
    case _ => false
  }

  /** Combines the inherited hash code with the hash code of the weights. */
  override def hashCode(): Int = {
    def getHashCode(a: Any): Int = if (a == null) 0 else a.hashCode()
    val state = Seq(super.hashCode(), weights)
    state.map(getHashCode).foldLeft(0)((a, b) => 31 * a + b)
  }

  override def toString(): String = {
    "nn.MultiCriterion"
  }
}

object MultiCriterion {
  /** Factory equivalent to `new MultiCriterion[T]()`. */
  def apply[@specialized(Float, Double) T: ClassTag]()
      (implicit ev: TensorNumeric[T]) : MultiCriterion[T] =
    new MultiCriterion[T]()
}
Example 45
Source File: Metrics.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.google.common.util.concurrent.AtomicDouble
import org.apache.spark.{Accumulable, Accumulator, SparkContext}

import scala.collection.mutable.{ArrayBuffer, Map}


/**
 * Registry of named metrics kept in three separate maps:
 *  - local metrics backed by an `AtomicDouble`,
 *  - aggregated distributed metrics backed by a Spark `Accumulator[Double]`,
 *  - distributed metrics backed by a Spark `Accumulable` collecting an `ArrayBuffer[Double]`.
 *
 * The `set` overloads register or overwrite a metric; `add` increments an existing one.
 */
class Metrics extends Serializable {
  private val localMetricsMap: Map[String, LocalMetricsEntry] = Map()
  private val aggregateDistributeMetricsMap: Map[String, AggregateDistributeMetricsEntry] = Map()
  private val distributeMetricsMap: Map[String, DistributeMetricsEntry] = Map()

  /**
   * Adds `value` to the metric called `name` in every map that contains it.
   *
   * Fails the `require` check when no map knows the name.
   */
  def add(name: String, value: Double): this.type = {
    val local = localMetricsMap.get(name)
    val aggregated = aggregateDistributeMetricsMap.get(name)
    val distributed = distributeMetricsMap.get(name)
    require(local.isDefined || aggregated.isDefined || distributed.isDefined)

    local.foreach(entry => entry.value.addAndGet(value))
    aggregated.foreach(entry => entry.value += value)
    distributed.foreach(entry => entry.value += value)
    this
  }

  /**
   * Sets (or creates) a local metric.
   *
   * The name must not already be registered as either kind of distributed metric.
   */
  def set(name: String, value: Double, parallel: Int = 1): this.type = {
    require(!aggregateDistributeMetricsMap.contains(name), "duplicated distribute metric")
    require(!distributeMetricsMap.contains(name), "duplicated distribute metric2")
    localMetricsMap.get(name) match {
      case Some(entry) =>
        entry.value.set(value)
        entry.parallel = parallel
      case None =>
        localMetricsMap(name) = LocalMetricsEntry(new AtomicDouble(value), parallel)
    }
    this
  }

  /**
   * Sets (or creates) an aggregated distributed metric; a new one gets an accumulator
   * obtained from `sc`. The name must not already be registered as a local metric.
   */
  def set(name: String, value: Double, sc: SparkContext, parallel: Int): this.type = {
    require(!localMetricsMap.contains(name), "duplicated local metric")
    aggregateDistributeMetricsMap.get(name) match {
      case Some(entry) =>
        entry.value.setValue(value)
        entry.parallel = parallel
      case None =>
        aggregateDistributeMetricsMap(name) =
          AggregateDistributeMetricsEntry(sc.accumulator(value, name), parallel)
    }
    this
  }

  /**
   * Sets (or creates) a distributed metric holding a buffer of values.
   *
   * The name must not already be registered as a local or aggregated distributed metric.
   */
  def set(name: String, value: ArrayBuffer[Double], sc: SparkContext): this.type = {
    require(!localMetricsMap.contains(name), "duplicated local metric")
    require(!aggregateDistributeMetricsMap.contains(name), "duplicated distribute metric")
    distributeMetricsMap.get(name) match {
      case Some(entry) =>
        entry.value.setValue(value)
      case None =>
        distributeMetricsMap(name) = DistributeMetricsEntry(sc.accumulableCollection(value))
    }
    this
  }

  /**
   * Reads a local or aggregated distributed metric; the local map is consulted first.
   *
   * @return the current value paired with the `parallel` setting stored with it
   */
  def get(name: String): (Double, Int) = {
    val local = localMetricsMap.get(name)
    require(local.isDefined || aggregateDistributeMetricsMap.contains(name))
    local match {
      case Some(entry) =>
        (entry.value.get(), entry.parallel)
      case None =>
        val entry = aggregateDistributeMetricsMap(name)
        (entry.value.value, entry.parallel)
    }
  }

  /**
   * Reads a distributed metric as an array, with the last `number` elements dropped.
   */
  def get(name: String, number: Int): Array[Double] = {
    require(distributeMetricsMap.contains(name))
    val collected = distributeMetricsMap(name).value.value
    collected.toArray.dropRight(number)
  }

  /**
   * Renders every metric as one line of text between a header and a footer.
   *
   * Local and aggregated values are divided by their `parallel` setting and by `scale`;
   * distributed values are divided by `scale` only.
   *
   * @param unit  label printed after local and aggregated values
   * @param scale divisor applied to every value
   */
  def summary(unit: String = "s", scale: Double = 1e9): String = {
    val localLines = localMetricsMap.map { case (name, entry) =>
      s"$name : ${entry.value.get() / entry.parallel / scale} $unit\n"
    }
    val aggregateLines = aggregateDistributeMetricsMap.map { case (name, entry) =>
      s"$name : ${entry.value.value / entry.parallel / scale} $unit\n"
    }
    val distributeLines = distributeMetricsMap.map { case (name, entry) =>
      s"$name : ${entry.value.value.map(_ / scale).mkString(" ")} \n"
    }

    "========== Metrics Summary ==========\n" +
      localLines.mkString("") +
      aggregateLines.mkString("") +
      distributeLines.mkString("") +
      "====================================="
  }
}


// Local metric: an atomically updatable value plus a mutable `parallel` setting, which
// Metrics.summary uses as a divisor when printing the value.
private case class LocalMetricsEntry(value: AtomicDouble, var parallel: Int)

// Aggregated distributed metric: a Spark accumulator of doubles plus the same mutable
// `parallel` setting.
private case class AggregateDistributeMetricsEntry(value: Accumulator[Double], var parallel: Int)

// Distributed metric: a Spark accumulable that gathers individual doubles into a buffer.
private case class DistributeMetricsEntry(value: Accumulable[ArrayBuffer[Double], Double]) 
Example 46
Source File: BatchSamplerSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.transform.vision.image.label.roi

import com.intel.analytics.bigdl.tensor.{Storage, Tensor}
import com.intel.analytics.bigdl.transform.vision.image.util.BoundingBox
import org.scalatest.{FlatSpec, Matchers}

import scala.collection.mutable.ArrayBuffer

/** Tests for [[BatchSampler]] sampling and its sample-constraint check. */
class BatchSamplerSpec extends FlatSpec with Matchers {

  /** Sixteen box coordinates shared by the constraint and multi-sampler tests. */
  private def overlapFixture: Array[Double] = Array(
    0.418, 0.396396, 0.55, 0.666667,
    0.438, 0.321321, 0.546, 0.561562,
    0.93, 0.81982, 0.966, 0.972973,
    0.872, 0.837838, 0.912, 0.981982)

  /**
   * Builds a label from 16 coordinates reshaped into a 4 x 4 float tensor, paired with
   * four randomly drawn class values.
   */
  private def roiTarget(coords: Array[Double]): RoiLabel = {
    val boxes = Tensor(Storage(coords.map(_.toFloat))).resize(4, 4)
    val classes = Tensor[Float](4).randn()
    RoiLabel(classes, boxes)
  }

  /** Sampler configuration shared by the overlap-constrained tests. */
  private def minOverlapSampler(overlap: Double): BatchSampler =
    new BatchSampler(minScale = 0.3, minAspectRatio = 0.5, maxAspectRatio = 2,
      minOverlap = Some(overlap))

  "batch sampler with no change" should "work properly" in {
    val sampler = new BatchSampler(maxTrials = 1)
    val unitBox = BoundingBox(0, 0, 1, 1)
    val target = roiTarget(Array(
      0.582296, 0.334719, 0.673582, 0.52183,
      0.596127, 0.282744, 0.670816, 0.449064,
      0.936376, 0.627859, 0.961272, 0.733888,
      0.896266, 0.640333, 0.923928, 0.740125))
    val sampledBoxes = new ArrayBuffer[BoundingBox]()

    sampler.sample(unitBox, target, sampledBoxes)

    sampledBoxes.length should be(1)
    sampledBoxes(0) should be(unitBox)
  }

  "satisfySampleConstraint with minOverlap 0.1" should "work properly" in {
    val target = roiTarget(overlapFixture)
    val sampledBox = BoundingBox(0.114741f, 0.248062f, 0.633665f, 0.763736f)
    val sampler = minOverlapSampler(0.1)

    sampler.satisfySampleConstraint(sampledBox, target) should be(true)
  }

  "satisfySampleConstraint with minOverlap 0.3" should "work properly" in {
    val target = roiTarget(overlapFixture)
    val sampledBox = BoundingBox(0.266885f, 0.416113f, 0.678256f, 0.67208f)
    val sampler = minOverlapSampler(0.3)

    sampler.satisfySampleConstraint(sampledBox, target) should be(true)
  }

  "batch samplers" should "work properly" in {
    val target = roiTarget(overlapFixture)
    val sampledBoxes = new ArrayBuffer[BoundingBox]()

    // One unconstrained sampler, five with increasing minimum overlap, one with a maximum.
    val unconstrained = new BatchSampler(maxTrials = 1)
    val minOverlapSamplers = Seq(0.1, 0.3, 0.5, 0.7, 0.9).map(minOverlapSampler)
    val maxOverlapSampler = new BatchSampler(minScale = 0.3, minAspectRatio = 0.5,
      maxAspectRatio = 2, maxOverlap = Some(1.0))
    val batchSamplers: Array[BatchSampler] =
      (unconstrained +: minOverlapSamplers :+ maxOverlapSampler).toArray

    BatchSampler.generateBatchSamples(target, batchSamplers, sampledBoxes)

    sampledBoxes.foreach(box => println(box))
  }
}
Example 47
Source File: BigDLSpecHelper.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils

import java.io.{File => JFile}

import org.apache.log4j.Logger
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.collection.mutable.ArrayBuffer

/**
 * Base class for specs that need temporary files and before/after hooks.
 *
 * Files obtained through `createTmpFile` are remembered and deleted after each test.
 * Subclasses customise setup and teardown by overriding `doBefore` and `doAfter`.
 */
abstract class BigDLSpecHelper extends FlatSpec with Matchers with BeforeAndAfter {
  protected val logger = Logger.getLogger(getClass)

  // every temp file handed out so far; swept in the `after` hook
  private val tmpFiles : ArrayBuffer[JFile] = new ArrayBuffer[JFile]()

  /** Creates a temp file, logs its path and registers it for cleanup. */
  protected def createTmpFile(): JFile = {
    val created = JFile.createTempFile("UnitTest", "BigDLSpecBase")
    logger.info(s"created file $created")
    tmpFiles += created
    created
  }

  /** Returns the part of `path` before its last file separator. */
  protected def getFileFolder(path: String): String = {
    val lastSeparator = path.lastIndexOf(JFile.separator)
    path.substring(0, lastSeparator)
  }

  /** Returns the part of `path` after its last file separator. */
  protected def getFileName(path: String): String = {
    val lastSeparator = path.lastIndexOf(JFile.separator)
    path.substring(lastSeparator + 1)
  }

  /** Hook run after each test, before temp files are removed; no-op by default. */
  def doAfter(): Unit = {}

  /** Hook run before each test; no-op by default. */
  def doBefore(): Unit = {}

  before {
    doBefore()
  }

  after {
    doAfter()
    // Only plain files are removed; a registered path that became a directory fails the check.
    for (f <- tmpFiles if f.exists()) {
      require(f.isFile, "cannot clean folder")
      f.delete()
      logger.info(s"deleted file $f")
    }
  }
}
Example 48
Source File: Kv2TensorSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.intel.analytics.bigdl.tensor.{DenseType, SparseType, Tensor}
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import com.intel.analytics.bigdl.utils.{T, Table}
import org.scalatest.{FlatSpec, Matchers}

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

/**
 * Tests Kv2Tensor on randomly generated "key:value" strings, comparing its output with a
 * sparse tensor (and the dense form of that tensor) assembled from the same random data.
 */
class Kv2TensorSpec extends FlatSpec with Matchers {

  /** Draws `length` random doubles, each computed as `lp + (up - lp) * nextDouble()`. */
  protected def randDoubles(length: Int,
                            lp: Double = 0.0,
                            up: Double = 1.0): Array[Double] =
    Array.fill(length)(lp + (up - lp) * Random.nextDouble())

  /**
   * Picks `numActive` distinct keys from `0 until size` and pairs them with random values.
   */
  protected def randKVMap(size: Int,
                          numActive: Int,
                          lp: Double = 0.0,
                          up: Double = 1.0): Map[Int, Double] = {
    require(numActive <= size)
    val activeKeys = Random.shuffle((0 until size).toList).take(numActive)
    val activeValues = randDoubles(numActive, lp, up)
    activeKeys.zip(activeValues).toMap
  }

  val batchLen = 3
  val numActive = Array(2, 3, 5)
  val feaLen = 8
  val originData = new ArrayBuffer[String]()
  val originArr = new ArrayBuffer[Table]()
  val indices0 = new ArrayBuffer[Int]()
  val indices1 = new ArrayBuffer[Int]()
  val values = new ArrayBuffer[Double]()

  // One random key/value map per batch row: keep its string form as operation input and its
  // (row, key, value) triples as the expected sparse content.
  for (row <- 0 until batchLen) {
    val activeCount = numActive(row)
    val kvMap = randKVMap(feaLen, activeCount)
    val kvStr = kvMap.map { case (key, value) => s"$key:$value" }.mkString(",")
    originData += kvStr
    originArr += T(kvStr)
    indices0 ++= ArrayBuffer.fill(activeCount)(row)
    val kvArr = kvMap.toArray
    indices1 ++= kvArr.map(_._1)
    values ++= kvArr.map(_._2)
  }

  val originTable = T.array(originArr.toArray)
  val indices = Array(indices0.toArray, indices1.toArray)
  val shape = Array(batchLen, feaLen)

  /** Input table: the key/value strings followed by a tensor carrying the feature length. */
  private def kvInput(): Table =
    T(
      Tensor[String](originTable),
      Tensor[Int](Array(feaLen), shape = Array[Int]())
    )

  /** Sparse tensor assembled from the recorded indices and values. */
  private def expectedSparse(): Tensor[Double] =
    Tensor.sparse[Double](
      indices = indices,
      values = values.toArray,
      shape = shape
    )

  "Kv2Tensor operation kvString to SparseTensor" should "work correctly" in {
    val input = kvInput()
    val expectOutput = expectedSparse()

    val output = Kv2Tensor[Double, Double](transType = 1)
      .forward(input)

    output should be(expectOutput)
  }

  "Kv2Tensor operation kvString to DenseTensor" should "work correctly" in {
    val input = kvInput()
    val expectOutput = Tensor.dense(expectedSparse())

    val output = Kv2Tensor[Double, Double](transType = 0)
      .forward(input)

    output should be(expectOutput)
  }
}

/** Runs the serialization test for a Kv2Tensor module on a fixed three-row input. */
class Kv2TensorSerialTest extends ModuleSerializationTest {
  override def test(): Unit = {
    val kv2tensor = Kv2Tensor[Float, Float](
      kvDelimiter = ",", itemDelimiter = ":", transType = 0
    ).setName("kv2tensor")

    val kvStrings = Tensor[String](
      T(T("0:0.1,1:0.2"), T("1:0.3,3:0.5"), T("2:0.15,4:0.25")))
    val featureLength = Tensor[Int](Array(5), shape = Array[Int]())

    runSerializationTest(kv2tensor, T(kvStrings, featureLength))
  }
}
Example 49
Source File: RMSpropSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.{T, TestUtils}
import org.scalatest.{FlatSpec, Matchers}

import scala.collection.mutable.ArrayBuffer

// @com.intel.analytics.bigdl.tags.Parallel
@com.intel.analytics.bigdl.tags.Serial
class RMSpropSpec extends FlatSpec with Matchers {
  // taken when the spec is instantiated; only feeds the "Time Cost" printout
  val start = System.currentTimeMillis()

  "RMSprop" should "perform well on rosenbrock function" in {
    val x = Tensor[Double](2).fill(0)
    val config = T("learningRate" -> 5e-4)
    val optm = new RMSprop[Double]
    // function value recorded at iterations 1, 1001, ..., 10001
    val fx = new ArrayBuffer[Double]
    for (step <- 0 to 10000) {
      val history = optm.optimize(TestUtils.rosenBrock, x, config)._2
      if (step % 1000 == 0) {
        fx += history(0)
      }
    }

    println(s"x is \n$x")
    println("fx is")
    fx.zipWithIndex.foreach { case (value, idx) =>
      println(s"${idx * 1000 + 1}, $value")
    }

    val spend = System.currentTimeMillis() - start
    println("Time Cost: " + spend + "ms")

    (fx.last < 1e-4) should be(true)
    x(Array(1)) should be(1.0 +- 0.01)
    x(Array(2)) should be(1.0 +- 0.01)
  }
}
Example 50
Source File: AdagradSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.utils.{TestUtils, T}
import org.scalatest.{FlatSpec, Matchers}
import com.intel.analytics.bigdl.tensor.Tensor

import scala.collection.mutable.ArrayBuffer

@com.intel.analytics.bigdl.tags.Parallel
class AdagradSpec extends FlatSpec with Matchers {
  "adagrad" should "perform well on rosenbrock function" in {
    val x = Tensor[Double](2).fill(0)
    val config = T("learningRate" -> 1e-1)
    val optm = new Adagrad[Double]
    // function value recorded at iterations 1, 1001, ..., 10001
    val fx = new ArrayBuffer[Double]
    for (step <- 0 to 10000) {
      val history = optm.optimize(TestUtils.rosenBrock, x, config)._2
      if (step % 1000 == 0) {
        fx += history(0)
      }
    }

    println(s"x is \n$x")
    println("fx is")
    fx.zipWithIndex.foreach { case (value, idx) =>
      println(s"${idx * 1000 + 1}, $value")
    }

    (fx.last < 1e-9) should be(true)
    x(Array(1)) should be(1.0 +- 0.01)
    x(Array(2)) should be(1.0 +- 0.01)
  }
}
Example 51
Source File: LBFGSSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.utils.{TestUtils, T}
import org.scalatest.{FlatSpec, Matchers}
import com.intel.analytics.bigdl.tensor.Tensor

import scala.collection.mutable.ArrayBuffer

@com.intel.analytics.bigdl.tags.Parallel
class LBFGSSpec extends FlatSpec with Matchers {

  /** Prints the final point and the numbered trace of function values, framed by blank lines. */
  private def printTrace(x: Tensor[Double], fx: scala.collection.Seq[Double]): Unit = {
    println()
    println("Rosenbrock test")
    println()

    println(s"x = $x")
    println("fx = ")
    fx.zipWithIndex.foreach { case (value, idx) =>
      println(s"${idx + 1} $value")
    }
    println()
    println()
  }

  "torchLBFGS in regular batch test" should "perform well on rosenbrock function" in {
    val x = Tensor[Double](2).fill(0)
    val optm = new LBFGS[Double]
    // a single optimize call running up to 100 inner iterations
    val result = optm.optimize(TestUtils.rosenBrock, x,
      T("maxIter" -> 100, "learningRate" -> 1e-1))
    val fx = result._2

    printTrace(x, fx)

    (fx.last < 1e-6) should be(true)
  }

  "torchLBFGS in stochastic test" should "perform well on rosenbrock function" in {
    val x = Tensor[Double](2).fill(0)
    val optm = new LBFGS[Double]
    val fx = new ArrayBuffer[Double]()

    // 100 optimize calls of one inner iteration each, reusing the same config table
    val config = T("maxIter" -> 1, "learningRate" -> 1e-1)
    for (_ <- 1 to 100) {
      val history = optm.optimize(TestUtils.rosenBrock, x, config)._2
      fx += history(0)
    }

    printTrace(x, fx)

    (fx.last < 1e-6) should be(true)
  }
}
Example 52
Source File: AdadeltaSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.{T, TestUtils}
import org.scalatest.{FlatSpec, Matchers}

import scala.collection.mutable.ArrayBuffer

@com.intel.analytics.bigdl.tags.Parallel
class AdadeltaSpec extends FlatSpec with Matchers {
  // taken when the spec is instantiated; only feeds the "Time Cost" printout
  val start = System.currentTimeMillis()

  "adadelta" should "perform well on rosenbrock function" in {
    val x = Tensor[Double](2).fill(0)
    val config = T("Epsilon" -> 1e-10)
    val optm = new Adadelta[Double]
    // function value recorded at iterations 1, 1001, ..., 10001
    val fx = new ArrayBuffer[Double]
    for (step <- 0 to 10000) {
      val history = optm.optimize(TestUtils.rosenBrock, x, config)._2
      if (step % 1000 == 0) {
        fx += history(0)
      }
    }

    println(s"x is \n$x")
    println("fx is")
    fx.zipWithIndex.foreach { case (value, idx) =>
      println(s"${idx * 1000 + 1}, $value")
    }

    val spend = System.currentTimeMillis() - start
    println("Time Cost: " + spend + "ms")

    (fx.last < 1e-4) should be(true)
    x(Array(1)) should be(1.0 +- 0.02)
    x(Array(2)) should be(1.0 +- 0.02)
  }
}
Example 53
Source File: AdamaxSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.{T, TestUtils}
import org.scalatest.{FlatSpec, Matchers}

import scala.collection.mutable.ArrayBuffer

@com.intel.analytics.bigdl.tags.Parallel
class AdamaxSpec extends FlatSpec with Matchers {
  // taken when the spec is instantiated; only feeds the "Time Cost" printout
  val start = System.currentTimeMillis()

  "adamax" should "perform well on rosenbrock function" in {
    val x = Tensor[Double](2).fill(0)
    val config = T()
    val optm = new Adamax[Double]
    // function value recorded at iterations 1, 1001, ..., 10001
    val fx = new ArrayBuffer[Double]
    for (step <- 0 to 10000) {
      val history = optm.optimize(TestUtils.rosenBrock, x, config)._2
      if (step % 1000 == 0) {
        fx += history(0)
      }
    }

    println(s"x is \n$x")
    println("fx is")
    fx.zipWithIndex.foreach { case (value, idx) =>
      println(s"${idx * 1000 + 1}, $value")
    }

    val spend = System.currentTimeMillis() - start
    println("Time Cost: " + spend + "ms")

    (fx.last < 1e-9) should be(true)
    x(Array(1)) should be(1.0 +- 0.01)
    x(Array(2)) should be(1.0 +- 0.01)
  }
}
Example 54
Source File: AdamSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.nn.{CrossEntropyCriterion, Linear, Sequential}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.{Engine, RandomGenerator, T, TestUtils}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

@com.intel.analytics.bigdl.tags.Parallel
class AdamSpec extends FlatSpec with Matchers with BeforeAndAfter {

  before {
    System.setProperty("bigdl.localMode", "true")
    System.setProperty("spark.master", "local[2]")
    Engine.init
  }

  after {
    System.clearProperty("bigdl.localMode")
    System.clearProperty("spark.master")
  }


  // taken when the spec is instantiated and shared by both tests' "Time Cost" printouts
  val start = System.currentTimeMillis()

  /** Prints the final point, the sampled function values and the time elapsed since `start`. */
  private def printTrace(x: Tensor[Double], fx: ArrayBuffer[Double]): Unit = {
    println(s"x is \n$x")
    println("fx is")
    fx.zipWithIndex.foreach { case (value, idx) =>
      println(s"${idx * 1000 + 1}, $value")
    }

    val spend = System.currentTimeMillis() - start
    println("Time Cost: " + spend + "ms")
  }

  "adam" should "perform well on rosenbrock function" in {
    val x = Tensor[Double](2).fill(0)
    val config = T("learningRate" -> 0.002)
    val optm = new Adam[Double]
    // function value recorded at iterations 1, 1001, ..., 10001
    val fx = new ArrayBuffer[Double]
    for (step <- 0 to 10000) {
      val history = optm.optimize(TestUtils.rosenBrock, x, config)._2
      if (step % 1000 == 0) {
        fx += history(0)
      }
    }

    printTrace(x, fx)

    (fx.last < 1e-9) should be(true)
    x(Array(1)) should be(1.0 +- 0.01)
    x(Array(2)) should be(1.0 +- 0.01)
  }

  "ParallelAdam" should "perform well on rosenbrock function" in {
    val x = Tensor[Double](2).fill(0)
    val optm = new ParallelAdam[Double](learningRate = 0.002, parallelNum = 2)
    // function value recorded at iterations 1, 1001, ..., 10001
    val fx = new ArrayBuffer[Double]
    for (step <- 0 to 10000) {
      val history = optm.optimize(TestUtils.rosenBrock, x)._2
      if (step % 1000 == 0) {
        fx += history(0)
      }
    }

    printTrace(x, fx)

    (fx.last < 1e-9) should be(true)
    x(Array(1)) should be(1.0 +- 0.01)
    x(Array(2)) should be(1.0 +- 0.01)
  }

}
Example 55
Source File: TrimmedIndependentPixelEvaluator.scala    From scalismo-faces   with Apache License 2.0 5 votes vote down vote up
package scalismo.faces.sampling.face.evaluators

import scalismo.color.{RGB, RGBA}
import scalismo.faces.image.{ImageBuffer, PixelImage, PixelImageDomain}
import scalismo.sampling.DistributionEvaluator
import scalismo.sampling.evaluators.PairEvaluator

import scala.collection.mutable.ArrayBuffer


    // Paints the given (value, x, y) entries onto an image that is None everywhere else and
    // passes the finished image to the callback.
    def visualize(values: IndexedSeq[(Double, Int, Int)], domain: PixelImageDomain, callBack: PixelImage[Option[Double]] => Unit): Unit = {
      val canvas = ImageBuffer.makeConstantBuffer[Option[Double]](domain.width, domain.height, None)
      values.foreach { case (value, col, row) =>
        canvas(col, row) = Some(value)
      }
      callBack(canvas.toImage)
    }
    // Walk every pixel: for samples whose alpha exceeds 1e-4, record the difference between
    // the pixel evaluator's and the background evaluator's log value together with the
    // pixel position. The alpha of every sample is summed to detect an empty rendering.
    var transparencySum = 0.0
    val values = ArrayBuffer[(Double, Int, Int)]()
    var x: Int = 0
    while (x < reference.width) {
      var y: Int = 0
      while (y < reference.height) {
        val smp = sample(x, y)
        if (smp.a > 1e-4f) {
          val ref = reference(x, y).toRGB
          val fg: Double = pixelEvaluator.logValue(ref, smp.toRGB)
          val bg: Double = bgEvaluator.logValue(ref)
          values += ((fg - bg, x, y))
        }
        transparencySum += smp.a
        y += 1
      }
      x += 1
    }
    // number of recorded entries that survive trimming
    val nCount = math.floor(values.length.toFloat * alphaClamped).toInt
    if (transparencySum > 0 && nCount > 0) {
      // was something rendered on the image?
      // Ascending sort: the nCount largest differences are the last nCount entries.
      val data = values.toIndexedSeq.sortBy(_._1)
      var sumTrimmed: Double = 0.0
      for (i <- 0 until nCount) {
        sumTrimmed += data(data.size - 1 - i)._1
      }
      // Show exactly the entries that were summed. The previous slice
      // [size - 1 - nCount, size - 1) was shifted by one position: it left out the largest
      // entry and included one entry that is not part of the sum.
      visualizationCallback.foreach { callback =>
        visualize(data.takeRight(nCount), reference.domain, callback)
      }
      sumTrimmed
    } else {
      // nothing was rendered on the image!
      Double.NegativeInfinity
    }
  }

  /**
   * Describes the evaluator by its two component evaluators and the clamped alpha.
   * NOTE(review): nothing separates the background evaluator's text from "alpha=";
   * confirm whether that is intended before changing the format.
   */
  override def toString: String = {
    val parts = Seq(
      "TrimmedIndependentPixelEvaluator(",
      pixelEvaluator.toString,
      "/",
      bgEvaluator.toString,
      s"alpha=$alphaClamped",
      ")")
    parts.mkString
  }
}

/** Factory methods for [[TrimmedIndependentPixelEvaluator]]. */
object TrimmedIndependentPixelEvaluator {

  /** Creates an evaluator without a visualization callback. */
  def apply(pixelEvaluator: PairEvaluator[RGB],
            bgEvaluator: DistributionEvaluator[RGB],
            alpha: Double) =
    new TrimmedIndependentPixelEvaluator(pixelEvaluator, bgEvaluator, alpha, None)

  /** Creates an evaluator that hands the image of retained values to `visualisationCallback`. */
  def apply(pixelEvaluator: PairEvaluator[RGB],
            bgEvaluator: DistributionEvaluator[RGB],
            alpha: Double,
            visualisationCallback: PixelImage[Option[Double]] => Unit) =
    new TrimmedIndependentPixelEvaluator(pixelEvaluator, bgEvaluator, alpha, Some(visualisationCallback))

}
Example 56
Source File: MorphologicalFilter.scala    From scalismo-faces   with Apache License 2.0 5 votes vote down vote up
package scalismo.faces.image.filter

import scalismo.faces.image.AccessMode._
import scalismo.faces.image._

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag


    // Gathers the image values under the active cells of the structuring element, with the
    // element centred on (x, y), and reduces them with windowFilter. When no cell is active
    // the pixel keeps its original value.
    def perPixel(x: Int, y: Int): A = {
      val left = x - width / 2
      val top = y - height / 2
      val selected = new ArrayBuffer[A](width * height)
      var col = 0
      while (col < width) {
        var row = 0
        while (row < height) {
          if (structuringElement(col, row)) {
            selected += image(left + col, top + row)
          }
          row += 1
        }
        col += 1
      }
      if (selected.isEmpty) image(x, y) else windowFilter(selected)
    }

    // A structuring element without area leaves the image untouched.
    if (width > 0 && height > 0)
      PixelImage(image.width, image.height, perPixel, Strict())
    else
      image
  }
}

object MorphologicalFilter {

  /** Builds a `size` x `size` boolean image view that is true for coordinates inside the square. */
  def boxElement(size: Int): PixelImage[Boolean] = {
    def inRange(coord: Int): Boolean = coord >= 0 && coord < size
    PixelImage.view(size, size, (x, y) => inRange(x) && inRange(y))
  }
}
Example 57
Source File: ImmutableSelection.scala    From hacktoberfest-scala-algorithms   with GNU General Public License v3.0 5 votes vote down vote up
package io.github.sentenza.hacktoberfest.algos

import scala.collection.mutable.ArrayBuffer
import scala.math.Ordered


  /**
   * Selects the element that would sit at position `idx` if `list` were sorted ascending.
   *
   * The head of the list serves as pivot; the remaining elements are split into those not
   * greater than the pivot and those greater, and the search continues in the side that
   * contains the requested position.
   *
   * @param list values to select from
   * @param idx  zero-based position in sorted order
   * @return the selected element, or None when `idx` is negative or not below the list size
   */
  def quickSelect(list: List[Int], idx: Int): Option[Int] =
    if (idx < 0 || idx >= list.size) None
    else list match {
      case Nil => None
      case pivot :: rest =>
        val (smaller, larger) = rest.partition(_ <= pivot)
        // in sorted order the pivot comes right after everything not greater than it
        val pivotIdx = smaller.size

        if (idx < pivotIdx) quickSelect(smaller, idx)
        else if (idx == pivotIdx) Some(pivot)
        else quickSelect(larger, idx - pivotIdx - 1)
    }
} 
Example 58
Source File: RocksEdgeFetcher.scala    From incubator-s2graph   with Apache License 2.0 5 votes vote down vote up
package org.apache.s2graph.core.storage.rocks

import com.typesafe.config.Config
import org.apache.s2graph.core._
import org.apache.s2graph.core.schema.Label
import org.apache.s2graph.core.storage.{SKeyValue, StorageIO, StorageSerDe}
import org.apache.s2graph.core.types.{HBaseType, VertexId}
import org.rocksdb.RocksDB

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}

/**
 * Edge fetcher backed by RocksDB: `db` is iterated for edges, and both `vdb` and `db` are
 * handed to the key/value lookup used by `fetches`.
 */
class RocksEdgeFetcher(val graph: S2GraphLike,
                       val config: Config,
                       val db: RocksDB,
                       val vdb: RocksDB,
                       val serDe: StorageSerDe,
                       val io: StorageIO) extends EdgeFetcher  {
  import RocksStorage._

  /**
   * Runs one key/value fetch per query request, converts the results to edges and keeps
   * only edges whose timestamp lies in the request's half-open duration window.
   *
   * @param queryRequests requests to execute
   * @param prevStepEdges edges of the previous step keyed by vertex id, used as parent edges
   * @return one step result per request, in request order
   */
  override def fetches(queryRequests: Seq[QueryRequest], prevStepEdges: Map[VertexId, Seq[EdgeWithScore]])(implicit ec: ExecutionContext): Future[Seq[StepResult]] = {
    val pending = queryRequests.map { request =>
      val parents = prevStepEdges.getOrElse(request.vertex.id, Nil)
      val requestEdge = graph.elementBuilder.toRequestEdge(request, parents)
      val rpc = buildRequest(graph, serDe, request, requestEdge)

      fetchKeyValues(vdb, db, rpc).map { kvs =>
        val queryParam = request.queryParam
        val fetched = io.toEdges(kvs, request, request.prevStepScore, false, parents)
        val inWindow = fetched.edgeWithScores.filter { scored =>
          // without an explicit duration every timestamp except Long.MaxValue passes
          val (from, until) = queryParam.durationOpt.getOrElse((Long.MinValue, Long.MaxValue))
          scored.edge.ts >= from && scored.edge.ts < until
        }

        fetched.copy(edgeWithScores = inWindow)
      }
    }

    Future.sequence(pending)
  }

  /**
   * Scans `db` once per distinct table name among all labels and collects the
   * deserialized edges that belong to that group's labels, point outward and are not
   * degree entries.
   */
  override def fetchEdgesAll()(implicit ec: ExecutionContext) = {
    val collected = new ArrayBuffer[S2EdgeLike]()
    val labelGroups = Label.findAll().groupBy(_.hbaseTableName).toSeq

    labelGroups.foreach { case (_, labels) =>
      val wantedLabels = labels.toSet

      val iter = db.newIterator()
      try {
        iter.seekToFirst()
        while (iter.isValid) {
          val kv = SKeyValue(table, iter.key(), SKeyValue.EdgeCf, qualifier, iter.value, System.currentTimeMillis())

          serDe.indexEdgeDeserializer(schemaVer = HBaseType.DEFAULT_VERSION).fromKeyValues(Seq(kv), None)
            .filter(e => wantedLabels(e.innerLabel) && e.getDirection() == "out" && !e.isDegree)
            .foreach(edge => collected += edge)

          iter.next()
        }
      } finally {
        // release the iterator even when deserialization throws
        iter.close()
      }
    }

    Future.successful(collected)
  }
}
Example 59
Source File: RocksVertexFetcher.scala    From incubator-s2graph   with Apache License 2.0 5 votes vote down vote up
package org.apache.s2graph.core.storage.rocks

import com.typesafe.config.Config
import org.apache.hadoop.hbase.util.Bytes
import org.apache.s2graph.core._
import org.apache.s2graph.core.schema.ServiceColumn
import org.apache.s2graph.core.storage.rocks.RocksStorage.{qualifier, table}
import org.apache.s2graph.core.storage.{SKeyValue, StorageIO, StorageSerDe}
import org.apache.s2graph.core.types.HBaseType
import org.rocksdb.RocksDB

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}

/**
 * Vertex fetcher backed by RocksDB: `vdb` is iterated for vertices, and both `vdb` and
 * `db` are handed to the key/value lookup used by `fetchVertices`.
 */
class RocksVertexFetcher(val graph: S2GraphLike,
                         val config: Config,
                         val db: RocksDB,
                         val vdb: RocksDB,
                         val serDe: StorageSerDe,
                         val io: StorageIO) extends VertexFetcher {
  /** Builds the storage request for `vertex` and fetches its key/values. */
  private def fetchKeyValues(queryRequest: QueryRequest, vertex: S2VertexLike)(implicit ec: ExecutionContext): Future[Seq[SKeyValue]] =
    RocksStorage.fetchKeyValues(vdb, db, RocksStorage.buildRequest(queryRequest, vertex))

  /**
   * Fetches the vertices named by `vertexQueryParam.vertexIds`, keeping those accepted by
   * the param's `where` filter. A lookup that fails contributes no vertices.
   */
  override def fetchVertices(vertexQueryParam: VertexQueryParam)(implicit ec: ExecutionContext): Future[Seq[S2VertexLike]] = {
    // Deserializes one vertex's key/values; the where filter is only resolved when there is data.
    def deserialize(kvs: Seq[SKeyValue], version: String): Seq[S2VertexLike] =
      if (kvs.nonEmpty) {
        serDe.vertexDeserializer(version).fromKeyValues(kvs, None).toSeq.filter(vertexQueryParam.where.get.filter)
      } else {
        Nil
      }

    val requested = vertexQueryParam.vertexIds.map(vId => graph.elementBuilder.newVertex(vId))

    val lookups = requested.map { vertex =>
      val queryParam = QueryParam.Empty
      val query = Query.toQuery(Seq(vertex), Seq(queryParam))
      val queryRequest = QueryRequest(query, stepIdx = -1, vertex, queryParam)

      fetchKeyValues(queryRequest, vertex)
        .map(kvs => deserialize(kvs, vertex.serviceColumn.schemaVersion))
        .recoverWith {
          // best effort: any failure for this vertex yields an empty result
          case _: Throwable => Future.successful(Nil)
        }
    }

    Future.sequence(lookups).map(_.flatten)
  }

  /**
   * Scans `vdb` once per distinct table name among all service columns, groups consecutive
   * rows into one vertex each and collects the vertices whose service column belongs to
   * that group.
   */
  override def fetchVerticesAll()(implicit ec: ExecutionContext) = {
    import scala.collection.mutable

    val vertices = new ArrayBuffer[S2VertexLike]()
    ServiceColumn.findAll().groupBy(_.service.hTableName).toSeq.foreach { case (_, columns) =>
      val distinctColumns = columns.toSet

      val iter = vdb.newIterator()
      val buffer = mutable.ListBuffer.empty[SKeyValue]
      var oldVertexIdBytes = Array.empty[Byte]
      var minusPos = 0

      // Deserializes the rows buffered for the current vertex, if any, and keeps the result
      // when its service column is one of this group's columns.
      def emitBuffered(): Unit =
        if (buffer.nonEmpty) {
          serDe.vertexDeserializer(schemaVer = HBaseType.DEFAULT_VERSION).fromKeyValues(buffer, None)
            .filter(v => distinctColumns(v.serviceColumn))
            .foreach(vertex => vertices += vertex)
        }

      try {
        iter.seekToFirst()
        while (iter.isValid) {
          val row = iter.key()
          // Rows are compared without their last byte (nothing is cut from the initial empty
          // marker); a difference means a new vertex starts here.
          val sameVertex =
            Bytes.equals(oldVertexIdBytes, 0, oldVertexIdBytes.length - minusPos, row, 0, row.length - 1)
          if (!sameVertex) {
            emitBuffered()

            oldVertexIdBytes = row
            minusPos = 1
            buffer.clear()
          }
          buffer += SKeyValue(table, iter.key(), SKeyValue.VertexCf, qualifier, iter.value(), System.currentTimeMillis())

          iter.next()
        }
        // the rows of the last vertex are still buffered when the scan ends
        emitBuffered()
      } finally {
        iter.close()
      }
    }

    Future.successful(vertices)
  }
}
Example 60
Source File: BytesUtilV1.scala    From incubator-s2graph   with Apache License 2.0 5 votes vote down vote up
package org.apache.s2graph.counter.core.v1

import org.apache.hadoop.hbase.util.Bytes
import org.apache.s2graph.counter.core.TimedQualifier.IntervalUnit
import org.apache.s2graph.counter.core.{TimedQualifier, ExactQualifier, ExactKeyTrait, BytesUtil}
import org.apache.s2graph.counter.models.Counter.ItemType
import org.apache.s2graph.counter.util.Hashes
import scala.collection.mutable.ArrayBuffer

/**
  * Version 1 byte layout helpers for counter row keys and qualifiers.
  *
  * Row key:   [hash(2b)][policy(4b)][item(variable)]
  * Qualifier: [interval][timestamp(8b)][dimension(variable)]
  */
object BytesUtilV1 extends BytesUtil {
  // ExactKey: [hash(2b)][policy(4b)][item(variable)]
  val BUCKET_BYTE_SIZE = Bytes.SIZEOF_SHORT
  val POLICY_ID_SIZE = Bytes.SIZEOF_INT
  val INTERVAL_SIZE = Bytes.SIZEOF_BYTE
  val TIMESTAMP_SIZE = Bytes.SIZEOF_LONG
  val TIMED_QUALIFIER_SIZE = INTERVAL_SIZE + TIMESTAMP_SIZE

  /** Row key prefix for a policy: the policy id encoded by `Bytes.toBytes`. */
  override def getRowKeyPrefix(id: Int): Array[Byte] = {
    Bytes.toBytes(id)
  }

  /**
    * Builds the row key for `key`: the first `BUCKET_BYTE_SIZE` bytes of the murmur3 hash of the
    * item key, followed by the policy prefix and the item key encoded according to its item type.
    */
  override def toBytes(key: ExactKeyTrait): Array[Byte] = {
    val buff = new ArrayBuffer[Byte]
    // hash key (2 byte)
    buff ++= Bytes.toBytes(Hashes.murmur3(key.itemKey)).take(BUCKET_BYTE_SIZE)

    buff ++= getRowKeyPrefix(key.policyId)
    buff ++= {
      key.itemType match {
        case ItemType.INT => Bytes.toBytes(key.itemKey.toInt)
        case ItemType.LONG => Bytes.toBytes(key.itemKey.toLong)
        case ItemType.STRING | ItemType.BLOB => Bytes.toBytes(key.itemKey)
      }
    }
    buff.toArray
  }

  /** Encodes an exact qualifier as the timed qualifier bytes followed by the dimension string. */
  override def toBytes(eq: ExactQualifier): Array[Byte] = {
    // Encode with Bytes.toBytes (UTF-8) rather than String.getBytes (platform default charset)
    // so the result always round-trips through Bytes.toString in toExactQualifier.
    toBytes(eq.tq) ++ Bytes.toBytes(eq.dimension)
  }

  /** Encodes a timed qualifier as the interval name followed by the timestamp. */
  override def toBytes(tq: TimedQualifier): Array[Byte] = {
    Bytes.toBytes(tq.q.toString) ++ Bytes.toBytes(tq.ts)
  }

  /** Decodes bytes produced by `toBytes(eq: ExactQualifier)`. */
  override def toExactQualifier(bytes: Array[Byte]): ExactQualifier = {
    // qualifier layout, in order: interval, ts, dimension
    val tq = toTimedQualifier(bytes)

    val dimension = Bytes.toString(bytes, TIMED_QUALIFIER_SIZE, bytes.length - TIMED_QUALIFIER_SIZE)
    ExactQualifier(tq, dimension)
  }

  /** Decodes the leading interval (`INTERVAL_SIZE` bytes) and the timestamp that follows it. */
  override def toTimedQualifier(bytes: Array[Byte]): TimedQualifier = {
    val interval = Bytes.toString(bytes, 0, INTERVAL_SIZE)
    val ts = Bytes.toLong(bytes, INTERVAL_SIZE)

    TimedQualifier(IntervalUnit.withName(interval), ts)
  }
}
Example 61
Source File: BytesUtilV2.scala    From incubator-s2graph   with Apache License 2.0 5 votes vote down vote up
package org.apache.s2graph.counter.core.v2

import org.apache.hadoop.hbase.util._
import org.apache.s2graph.counter
import org.apache.s2graph.counter.core.TimedQualifier.IntervalUnit
import org.apache.s2graph.counter.core.{TimedQualifier, ExactQualifier, ExactKeyTrait, BytesUtil}
import org.apache.s2graph.counter.models.Counter.ItemType
import org.apache.s2graph.counter.util.Hashes
import scala.collection.mutable.ArrayBuffer

/**
  * Version 2 byte layout helpers for counter row keys and qualifiers, using HBase
  * `OrderedBytes` encodings for the qualifier parts.
  */
object BytesUtilV2 extends BytesUtil {
  // ExactKey layout: [hash(1b)][version(1b)][policy(4b)][item(variable)]
  val BUCKET_BYTE_SIZE = Bytes.SIZEOF_BYTE
  val VERSION_BYTE_SIZE = Bytes.SIZEOF_BYTE
  val POLICY_ID_SIZE = Bytes.SIZEOF_INT

  val INTERVAL_SIZE = Bytes.SIZEOF_BYTE
  val TIMESTAMP_SIZE = Bytes.SIZEOF_LONG
  val TIMED_QUALIFIER_SIZE = INTERVAL_SIZE + TIMESTAMP_SIZE

  /** Row key prefix: the version marker followed by the encoded policy id. */
  override def getRowKeyPrefix(id: Int): Array[Byte] =
    Array(counter.VERSION_2) ++ Bytes.toBytes(id)

  /** Builds the row key: hash bucket, then version + policy id, then the typed item key. */
  override def toBytes(key: ExactKeyTrait): Array[Byte] = {
    val rowKey = new ArrayBuffer[Byte]

    // leading hash byte(s) used as bucket
    val bucket = Bytes.toBytes(Hashes.murmur3(key.itemKey)).take(BUCKET_BYTE_SIZE)
    rowKey ++= bucket

    // version + policy id
    val prefix = getRowKeyPrefix(key.policyId)
    rowKey ++= prefix

    // item key, encoded according to its declared type
    val item = key.itemType match {
      case ItemType.INT => Bytes.toBytes(key.itemKey.toInt)
      case ItemType.LONG => Bytes.toBytes(key.itemKey.toLong)
      case ItemType.STRING | ItemType.BLOB => Bytes.toBytes(key.itemKey)
    }
    rowKey ++= item

    rowKey.toArray
  }

  /** Encodes the timed qualifier followed by every sorted dimension string. */
  override def toBytes(eq: ExactQualifier): Array[Byte] = {
    // each key and each value is budgeted its character count plus two extra bytes
    val capacity = eq.dimKeyValues.foldLeft(0) { case (total, (k, v)) =>
      total + k.length + 2 + v.length + 2
    }
    val range = new SimplePositionedMutableByteRange(capacity)
    ExactQualifier.makeSortedDimension(eq.dimKeyValues).foreach { dim =>
      OrderedBytes.encodeString(range, dim, Order.ASCENDING)
    }
    toBytes(eq.tq) ++ range.getBytes
  }

  /** Encodes the interval name (ascending) followed by the timestamp (descending). */
  override def toBytes(tq: TimedQualifier): Array[Byte] = {
    val range = new SimplePositionedMutableByteRange(INTERVAL_SIZE + 2 + TIMESTAMP_SIZE + 1)
    OrderedBytes.encodeString(range, tq.q.toString, Order.ASCENDING)
    OrderedBytes.encodeInt64(range, tq.ts, Order.DESCENDING)
    range.getBytes
  }

  // Lazily decodes strings from `pbr` until nothing remains to be read.
  private def decodeString(pbr: PositionedByteRange): Stream[String] =
    if (pbr.getRemaining <= 0) Stream.empty
    else Stream.cons(OrderedBytes.decodeString(pbr), decodeString(pbr))

  /**
    * Decodes a timed qualifier and then the remaining strings; the first half of those strings
    * are taken as keys and the second half as the matching values.
    */
  override def toExactQualifier(bytes: Array[Byte]): ExactQualifier = {
    val range = new SimplePositionedByteRange(bytes)
    val timed = toTimedQualifier(range)
    val decoded = decodeString(range).toSeq
    val (keys, values) = decoded.splitAt(decoded.length / 2)
    ExactQualifier(timed, keys.zip(values).toMap)
  }

  /** Decodes a timed qualifier from the start of `bytes`. */
  override def toTimedQualifier(bytes: Array[Byte]): TimedQualifier =
    toTimedQualifier(new SimplePositionedByteRange(bytes))

  /** Decodes a timed qualifier from the current position of `pbr`, advancing it. */
  def toTimedQualifier(pbr: PositionedByteRange): TimedQualifier = {
    // read order matters: interval name first, timestamp second
    val interval = IntervalUnit.withName(OrderedBytes.decodeString(pbr))
    val timestamp = OrderedBytes.decodeInt64(pbr)
    TimedQualifier(interval, timestamp)
  }
}
Example 62
Source File: AccountStorage.scala    From matcher   with MIT License 5 votes vote down vote up
package com.wavesplatform.dex.db

import java.io.{File, FileInputStream, FileOutputStream}
import java.nio.file.Files
import java.util.Base64

import cats.syntax.either._
import com.google.common.primitives.{Bytes, Ints}
import com.wavesplatform.dex.crypto.Enigma
import com.wavesplatform.dex.db.AccountStorage.Settings.EncryptedFile
import com.wavesplatform.dex.domain.account.KeyPair
import com.wavesplatform.dex.domain.bytes.ByteStr
import com.wavesplatform.dex.domain.crypto
import net.ceedubs.ficus.readers.ValueReader

import scala.collection.mutable.ArrayBuffer

/** Wraps the key pair obtained from the configured account storage. */
case class AccountStorage(keyPair: KeyPair)

object AccountStorage {

  /** Where the account seed comes from. */
  sealed trait Settings

  object Settings {

    /** Seed supplied directly in the configuration. */
    case class InMem(seed: ByteStr)                        extends Settings

    /** Seed stored in a file encrypted with `password`. */
    case class EncryptedFile(path: File, password: String) extends Settings

    /** Reads [[Settings]] from config, dispatching on the `type` key. */
    implicit val valueReader: ValueReader[Settings] = ValueReader.relative[Settings] { config =>
      config.getString("type") match {
        case "in-mem" => InMem(Base64.getDecoder.decode(config.getString("in-mem.seed-in-base64")))
        case "encrypted-file" =>
          EncryptedFile(
            path = new File(config.getString("encrypted-file.path")),
            password = config.getString("encrypted-file.password")
          )
        case x => throw new IllegalArgumentException(s"The type of account storage '$x' is unknown. Please update your settings.")
      }
    }
  }

  /**
    * Loads the account storage described by `settings`.
    *
    * @return the storage, or a message when the encrypted file does not exist
    */
  def load(settings: Settings): Either[String, AccountStorage] = settings match {
    case Settings.InMem(seed) => Right(AccountStorage(KeyPair(seed)))
    case Settings.EncryptedFile(file, password) =>
      if (file.isFile) {
        val encryptedSeedBytes = readFile(file)
        val key                = Enigma.prepareDefaultKey(password)
        val decryptedBytes     = Enigma.decrypt(key, encryptedSeedBytes)
        AccountStorage(KeyPair(decryptedBytes)).asRight
      } else s"A file '${file.getAbsolutePath}' doesn't exist".asLeft
  }

  /** Encrypts `seed` with the password from `to` and writes it to `to.path`, creating parent directories. */
  def save(seed: ByteStr, to: EncryptedFile): Unit = {
    // getParentFile is null for a bare file name such as "seed.dat"; nothing to create then
    Option(to.path.getParentFile).foreach(dir => Files.createDirectories(dir.toPath))
    val key                = Enigma.prepareDefaultKey(to.password)
    val encryptedSeedBytes = Enigma.encrypt(key, seed.arr)
    writeFile(to.path, encryptedSeedBytes)
  }

  /** Derives a seed as the secure hash of the nonce bytes followed by `baseSeed`. */
  def getAccountSeed(baseSeed: ByteStr, nonce: Int): ByteStr = ByteStr(crypto.secureHash(Bytes.concat(Ints.toByteArray(nonce), baseSeed)))

  /** Reads the whole content of `file`; the stream is always closed. */
  def readFile(file: File): Array[Byte] = {
    val reader = new FileInputStream(file)
    try {
      val buff = new Array[Byte](1024)
      val r    = new ArrayBuffer[Byte]
      // read() returns -1 only at end of stream; available() is merely an estimate and may be 0
      // before the end is reached, so it must not be used to detect EOF.
      Iterator
        .continually(reader.read(buff))
        .takeWhile(_ != -1)
        .foreach(read => r.appendAll(buff.iterator.take(read)))
      r.toArray
    } finally {
      reader.close()
    }
  }

  /** Overwrites `file` with `bytes`; the stream is always closed. */
  def writeFile(file: File, bytes: Array[Byte]): Unit = {
    val writer = new FileOutputStream(file, false)
    try writer.write(bytes)
    finally writer.close()
  }
}
Example 63
Source File: WordSpliter.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package cn.piflow.bundle.nlp

import cn.piflow._
import cn.piflow.conf._
import cn.piflow.conf.bean.PropertyDescriptor
import cn.piflow.conf.util.{ImageUtil, MapUtil}
import com.huaban.analysis.jieba.JiebaSegmenter.SegMode
import com.huaban.analysis.jieba._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

/**
  * Stop that reads a text file, segments its first line into words with jieba and writes the
  * words out as a single-column ("words") DataFrame.
  */
class WordSpliter extends ConfigurableStop {

  val authorEmail: String = "[email protected]"
  val description: String = "Word segmentation"
  val inportList: List[String] = List(Port.AnyPort.toString)
  val outportList: List[String] = List(Port.DefaultPort.toString)

  // location of the text file to read, set through setProperties
  var path:String = _

  val jiebaSegmenter = new JiebaSegmenter()
  // words produced by segmenter; appended to on every call
  var tokenARR:ArrayBuffer[String]=ArrayBuffer()

  /** Removes punctuation, the listed symbols and whitespace from `str`, then appends each segmented word to `tokenARR`. */
  def segmenter(str:String): Unit = {
    val cleaned = str.replaceAll( "[\\p{P}+~$`^=|<>~`$^+=|<>¥×+\\s]" , "")

    jiebaSegmenter
      .process(cleaned, SegMode.SEARCH)
      .asScala
      .foreach(token => tokenARR += token.word)
  }

  /** Segments the first line of the file at `path` and writes one row per word to `out`. */
  def perform(in: JobInputStream, out: JobOutputStream, pec: JobContext): Unit = {
    val session: SparkSession = pec.get[SparkSession]()

    // only the first line of the text file is segmented
    val firstLine = session.read.text(path).head().getString(0)
    segmenter(firstLine)

    // one single-cell row per collected word
    val rows: List[Row] = tokenARR.map { word =>
      val cells: Array[String] = Array(word)
      Row.fromSeq(cells)
    }.toList

    val schema: StructType = StructType(Array(StructField("words",StringType)))
    val rowRDD: RDD[Row] = session.sparkContext.makeRDD(rows)
    val df: DataFrame = session.createDataFrame(rowRDD,schema)

    out.write(df)
  }

  def initialize(ctx: ProcessContext): Unit = {}

  /** Reads the "path" entry of `map` into `path`. */
  def setProperties(map : Map[String, Any]) = {
    path = MapUtil.get(map,"path").asInstanceOf[String]
  }

  /** Describes the single required "path" property. */
  override def getPropertyDescriptor(): List[PropertyDescriptor] = {
    val pathDescriptor = new PropertyDescriptor()
      .name("path")
      .displayName("path")
      .description("The path of text file")
      .defaultValue("")
      .required(true)
    List(pathDescriptor)
  }

  override def getIcon(): Array[Byte] = ImageUtil.getImage("icon/nlp/NLP.png")

  override def getGroup(): List[String] = List(StopGroup.Alg_NLPGroup.toString)

}
Example 64
Source File: JsonUtil.scala    From piflow   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package cn.piflow.bundle.util

import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.{Column, DataFrame, SQLContext, SparkSession}

import scala.collection.mutable.ArrayBuffer

object JsonUtil extends Serializable{

  /**
    * Selects the fields named in `tag` from `df`, optionally exploding one array field.
    *
    * `tag` is a comma-separated list of field names. An entry written as
    * `MasterField_ChildField` (for example `links_name`) marks `MasterField` as the array field
    * to explode and `ChildField` as one of the child fields to select from it. If several
    * entries contain "_", the master field of the last one is used and all child fields are kept.
    *
    * @param df  the source DataFrame
    * @param tag comma-separated field list as described above; empty selects nothing and
    *            returns `df` unchanged
    * @return the projected (and possibly exploded) DataFrame
    */
  def ParserJsonDF(df:DataFrame,tag:String): DataFrame = {

    var openArrField:String=""
    var ArrSchame:String=""

    var tagARR: Array[String] = tag.split(",")
    var tagNew:String=""


    for(tt<-tagARR){

      if(tt.indexOf("_")> -1){
        // entry contains the "_" separator: master array field + child field
        val openField: Array[String] = tt.split("_")
        openArrField=openField(0)

        ArrSchame+=(openField(1)+",")
      }else{
        tagNew+=(tt+",")
      }
    }
    tagNew+=openArrField
    // Drop the trailing comma, if any. substring(0, length - 1) threw
    // StringIndexOutOfBoundsException whenever no entry contained "_".
    ArrSchame=ArrSchame.stripSuffix(",")

    tagARR = tagNew.split(",")
    var FinalDF:DataFrame=df

    // if the user chose fields to return
    var strings: Seq[Column] =tagNew.split(",").toSeq.map(p => new Column(p))

    if(tag.length>0){
      val df00 = FinalDF.select(strings : _*)
      FinalDF=df00
    }

    // if the user chose an array field to open and gave its schema
    if(openArrField.length>0&&ArrSchame.length>0){

      val schames: Array[String] = ArrSchame.split(",")

      var selARR:ArrayBuffer[String]=ArrayBuffer()// names used to pick out the opened fields one by one
      // walk the tags, wrapping each one into a Column
      var coARR:ArrayBuffer[Column]=ArrayBuffer()// columns passed to select when opening the field
      val sss = tagNew.split(",")// column names passed to toDF after opening the field
      var co: Column =null
      for(each<-tagARR){
        if(each==openArrField){
          co = explode(FinalDF(openArrField))
          for(x<-schames){

            selARR+=(openArrField+"."+x)
          }
        }else{
          selARR+=each
          co=FinalDF(each)
        }
        coARR+=co
      }
      println("###################")
      selARR.foreach(println(_))
      var selSEQ: Seq[Column] = selARR.toSeq.map(q => new Column(q))

      var df01: DataFrame = FinalDF.select(coARR : _*).toDF(sss:_*)
      FinalDF = df01.select(selSEQ : _*)

    }

    FinalDF

  }
}
Example 65
Source File: BufferListener.scala    From Binding.scala   with MIT License 5 votes vote down vote up
package com.thoughtworks.binding

import Binding.{PatchedEvent, ChangedEvent, PatchedListener, ChangedListener}
import com.thoughtworks.binding.Binding.{PatchedEvent, ChangedEvent, PatchedListener, ChangedListener}

import scala.collection.mutable.ArrayBuffer


/** An `ArrayBuffer` that records every changed/patched event delivered to its `listener`. */
final class BufferListener extends ArrayBuffer[Any] { recorder =>

  // Appends each received event to the enclosing buffer, in arrival order.
  val listener = new ChangedListener[Seq[Any]] with PatchedListener[Any] {

    override def patched(event: PatchedEvent[Any]): Unit = {
      recorder += event
    }

    override def changed(event: ChangedEvent[Seq[Any]]): Unit = {
      recorder += event
    }
  }
}
Example 66
Source File: FlatMapRemove.scala    From Binding.scala   with MIT License 5 votes vote down vote up
package com.thoughtworks.binding.regression

import com.thoughtworks.binding.Binding._
import com.thoughtworks.binding._
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers

import scala.collection.mutable.ArrayBuffer


/** Regression test: removing the only `Left` element must re-trigger the dependent binding. */
final class FlatMapRemove extends AnyFreeSpec with Matchers {
  "removed source of a flatMap" in {

    val data = Vars.empty[Either[String, String]]

    // only the Left elements of `data`
    val left = for {
      s <- data
      if s.isLeft
    } yield s

    val events = ArrayBuffer.empty[String]
    val autoPrint = Binding {
      if (left.length.bind > 0) {
        events += "has left"
      } else {
        events += "does not has left"
      }
    }

    // true while no "has left" event has been recorded (also true when nothing was recorded)
    def sawNoLeft: Boolean = events.forall(_ == "does not has left")

    assert(sawNoLeft)
    autoPrint.watch()
    assert(sawNoLeft)

    // adding Right values must never report a Left
    data.value += Right("1")
    assert(sawNoLeft)
    data.value += Right("2")
    assert(sawNoLeft)
    data.value += Right("3")
    assert(sawNoLeft)

    // replacing an element with a Left flips the binding ...
    data.value(1) = Left("left 2")
    assert(events.last == "has left")

    // ... and removing it flips the binding back
    data.value --= Seq(Left("left 2"))
    assert(events.last == "does not has left")
  }
}
Example 67
Source File: InsertThenClear.scala    From Binding.scala   with MIT License 5 votes vote down vote up
package com.thoughtworks.binding.regression

import com.thoughtworks.binding.Binding._
import com.thoughtworks.binding._
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers

import scala.collection.mutable.ArrayBuffer


/** Regression test: a mapped sequence must follow an insertAll followed by a clear on its source. */
final class InsertThenClear extends AnyFreeSpec with Matchers {
  "insert then clear" in {
    val items = Vars(1 to 10: _*)
    val mapped = items.map(-_)
    mapped.watch()

    val negatedInitial = Seq(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10)
    assert(mapped.get sameElements negatedInitial)

    // insert four elements after the third one
    items.value.insertAll(3, 100 to 103)
    val negatedAfterInsert = Seq(-1, -2, -3, -100, -101, -102, -103, -4, -5, -6, -7, -8, -9, -10)
    assert(mapped.get sameElements negatedAfterInsert)

    items.value.clear()
    assert(mapped.get sameElements Seq.empty)
  }
}
Example 68
Source File: ProxyMessageHandler.scala    From spark-riak-connector   with Apache License 2.0 5 votes vote down vote up
package com.basho.riak.stub

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels._

import com.basho.riak.client.core.RiakMessage
import com.basho.riak.client.core.util.HostAndPort
import shaded.com.basho.riak.protobuf.RiakKvPB
import shaded.com.basho.riak.protobuf.RiakMessageCodes._
import shaded.com.google.protobuf.ByteString

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer


/**
  * Message handler that forwards every request to the Riak node at `hostAndPort` and returns
  * that node's response(s). Coverage plan responses are rewritten so that entries pointing at
  * the real node point at the local address of the client channel instead.
  */
class ProxyMessageHandler(hostAndPort: HostAndPort) extends RiakMessageHandler {

  private final val riakAddress = new InetSocketAddress(hostAndPort.getHost, hostAndPort.getPort)

  override def handle(context: ClientHandler.Context, input: RiakMessage): Iterable[RiakMessage] = input.getCode match {
    // coverage plan received from real Riak node must be modified to replace real node's host and port with proxy
    case MSG_CoverageReq => forwardAndTransform(context, input) { output =>
      val resp = RiakKvPB.RpbCoverageResp.parseFrom(output.getData)
      val modified = RiakKvPB.RpbCoverageResp.newBuilder(resp)
        .clearEntries()
        .addAllEntries(resp.getEntriesList.map { ce =>
          val ceBuilder = RiakKvPB.RpbCoverageEntry.newBuilder(ce)
          if (ce.getIp.toStringUtf8 == hostAndPort.getHost && ce.getPort == hostAndPort.getPort) {
            val localAddress = context.channel.asInstanceOf[NetworkChannel]
              .getLocalAddress.asInstanceOf[InetSocketAddress]
            ceBuilder.setIp(ByteString.copyFromUtf8(localAddress.getHostString))
            ceBuilder.setPort(localAddress.getPort)
          }
          ceBuilder.build()
        }).build()
      new RiakMessage(output.getCode, modified.toByteArray)
    }
    case _ => forwardMessage(context, input)
  }

  /**
    * Opens a new connection to the real node, sends `input` and reads messages until
    * `isDoneReceived` reports the response as complete. The connection is always closed.
    */
  private def forwardMessage(context: ClientHandler.Context, input: RiakMessage): Iterable[RiakMessage] = {
    def readRiakResponse(channel: SocketChannel, out: List[RiakMessage] = Nil): Iterable[RiakMessage] =
      if (isDoneReceived(out, input)) out
      else readRiakResponse(channel, out ++ readSocket(channel))

    val channel = SocketChannel.open(riakAddress)
    try {
      // forward request to real Riak node
      // The write is performed outside assert(): Predef.assert is elidable, and an elided
      // assertion would silently skip the write itself.
      val written = channel.write(RiakMessageEncoder.encode(input))
      assert(written > 0)

      // read response for forwarded request from real Riak node
      readRiakResponse(channel)
    } finally {
      channel.close()
    }
  }

  /**
    * Reads from `channel` until at least one message has been decoded and no undecoded bytes
    * remain in the accumulator.
    *
    * @throws java.io.EOFException if the channel reaches end-of-stream before that point
    */
  private def readSocket(channel: SocketChannel): Iterable[RiakMessage] = {
    var accumulator = ByteBuffer.allocateDirect(0)

    val out = ArrayBuffer[RiakMessage]()
    while (out.isEmpty || accumulator.hasRemaining) {
      // try to parse riak message from bytes in accumulator buffer
      RiakMessageEncoder.decode(accumulator) match {
        case Some(x) =>
          accumulator = accumulator.slice()
          out += x
        case None =>
          // read next chunk of data from channel and add it into accumulator
          val in = ByteBuffer.allocateDirect(1024) // scalastyle:ignore
          if (channel.read(in) < 0) {
            // end-of-stream: no more bytes will ever arrive, so looping again could never finish
            throw new java.io.EOFException(
              s"Connection to $riakAddress was closed before a complete response was received")
          }
          accumulator = ByteBuffer
            .allocate(accumulator.rewind().limit() + in.flip().limit())
            .put(accumulator)
            .put(in)
          accumulator.rewind()
          in.clear()
      }
    }
    out
  }

  // An index request is complete once any response carries the `done` flag;
  // any other request is complete as soon as one message has been received.
  private def isDoneReceived(out: Iterable[RiakMessage], input: RiakMessage): Boolean = input.getCode match {
    case MSG_IndexReq => out.exists(m => RiakKvPB.RpbIndexResp.parseFrom(m.getData).getDone)
    case _ => out.nonEmpty
  }

  // Forwards `input` and applies `transform` to every message of the response.
  private def forwardAndTransform(context: ClientHandler.Context, input: RiakMessage
                                 )(transform: RiakMessage => RiakMessage
                                 ): Iterable[RiakMessage] = forwardMessage(context, input).map(transform(_))

  override def onRespond(input: RiakMessage, output: Iterable[RiakMessage]): Unit = {}
}
Example 69
Source File: QueryBucketKeys.scala    From spark-riak-connector   with Apache License 2.0 5 votes vote down vote up
package com.basho.riak.spark.query

import com.basho.riak.client.core.query.Location
import com.basho.riak.spark.rdd.connector.RiakConnector
import com.basho.riak.spark.rdd.{BucketDef, ReadConf}

import scala.collection.mutable.ArrayBuffer

/** Query over an explicit set of keys in `bucket`, turned into locations one chunk at a time. */
private case class QueryBucketKeys(bucket: BucketDef,
                                   readConf:ReadConf,
                                   riakConnector: RiakConnector,
                                   keys: Iterable[String]
                                  ) extends QuerySubsetOfKeys[String] {

  /**
    * Consumes keys from the iterator, converting each to a location in the bucket's namespace,
    * and stops right after the key that brings the chunk up to `readConf.fetchSize`.
    *
    * @return `false` paired with the collected locations
    */
  override def locationsByKeys(keys: Iterator[String]): (Boolean, Iterable[Location]) = {
    val namespace = bucket.asNamespace()
    val chunk = new ArrayBuffer[Location](readConf.fetchSize)

    // at least one key is taken (when available) before the size limit is checked
    var roomLeft = true
    while (roomLeft && keys.hasNext) {
      chunk += new Location(namespace, keys.next())
      roomLeft = chunk.size < readConf.fetchSize
    }

    (false, chunk)
  }
}
Example 70
Source File: Query2iKeys.scala    From spark-riak-connector   with Apache License 2.0 5 votes vote down vote up
package com.basho.riak.spark.query

import com.basho.riak.client.core.operations.CoveragePlanOperation.Response.CoverageEntry
import com.basho.riak.client.core.query.Location
import com.basho.riak.spark.rdd.connector.RiakConnector
import com.basho.riak.spark.rdd.{BucketDef, ReadConf}

import scala.collection.mutable.ArrayBuffer

/**
  * Query over a set of secondary-index keys. Each key is resolved through a
  * `Query2iKeySingleOrRange`, whose results are fetched chunk by chunk and carried across calls.
  */
private case class Query2iKeys[K](bucket: BucketDef,
                                  readConf:ReadConf,
                                  riakConnector: RiakConnector,
                                  index: String,
                                  keys: Iterable[K]
                                 ) extends QuerySubsetOfKeys[K] {
  // query for the key currently being processed, and its continuation token (if any)
  private var query2iKey: Option[Query2iKeySingleOrRange[K]] = None
  private var tokenNext: Option[Either[String, CoverageEntry]] = None

  // By default there should be an empty Serializable Iterator
  private var _iterator: Iterator[Location] = ArrayBuffer.empty[Location].iterator

  private def chunkIsCollected(chunk: Iterable[Location]) = chunk.size >= readConf.fetchSize

  // scalastyle:off cyclomatic.complexity
  /**
    * Fills a chunk of up to `readConf.fetchSize` locations from, in order: results left over
    * from the previous fetch, further chunks of the current key's query, then the next keys.
    *
    * @return whether a continuation token is still pending, paired with the collected chunk
    */
  override def locationsByKeys(keys: Iterator[K]): (Boolean, Iterable[Location]) = {
    val dataBuffer = new ArrayBuffer[Location](readConf.fetchSize)

    // Asks the current query for its next chunk and remembers the returned token and results.
    def loadNextChunk(): Unit = {
      val r = query2iKey.get.nextLocationChunk(tokenNext)
      tokenNext = r._1
      _iterator = r._2.iterator
    }

    while ((keys.hasNext || _iterator.hasNext || tokenNext.isDefined) && !chunkIsCollected(dataBuffer)){
      // Previously gathered results should be returned at first, if any
      var roomLeft = true
      while (roomLeft && _iterator.hasNext) {
        dataBuffer += _iterator.next()
        roomLeft = !chunkIsCollected(dataBuffer)
      }

      if (!chunkIsCollected(dataBuffer)) {
        if (tokenNext.isDefined) {
          // Fetch the next results page from the previously executed 2i query, if any
          assert(query2iKey.isDefined)
          loadNextChunk()
        } else if (keys.hasNext) {
          // query data for the first/next key
          assert(_iterator.isEmpty && tokenNext.isEmpty)

          val key = keys.next()
          query2iKey = Some(new Query2iKeySingleOrRange[K](bucket, readConf, riakConnector, index, key))
          loadNextChunk()
        }
        // otherwise there is nothing to do
      }
    }
    (tokenNext.isDefined, dataBuffer)
  }
  // scalastyle:on cyclomatic.complexity
}
Example 71
Source File: Partitioner.scala    From spark-solr   with Apache License 2.0 5 votes vote down vote up
package com.lucidworks.spark

import java.net.InetAddress

import com.lucidworks.spark.rdd.SolrRDD
import com.lucidworks.spark.util.SolrSupport
import org.apache.solr.client.solrj.SolrQuery
import org.apache.spark.Partition

import scala.collection.mutable.ArrayBuffer

// Is there a need to override {@code Partitioner.scala} and define our own partition id's
/** Builds Spark partitions from Solr shards, either one per shard or one per shard split. */
object SolrPartitioner {

  /** One partition per shard, numbered by shard position, each bound to a random replica. */
  def getShardPartitions(shards: List[SolrShard], query: SolrQuery) : Array[Partition] =
    shards.iterator.zipWithIndex.map { case (shard, idx) =>
      // Chose any of the replicas as the active shard to query
      SelectSolrRDDPartition(idx, "*", shard, query, SolrRDD.randomReplica(shard))
    }.toArray

  /** One partition per split of every shard, numbered consecutively across all shards. */
  def getSplitPartitions(
      shards: List[SolrShard],
      query: SolrQuery,
      splitFieldName: String,
      splitsPerShard: Int): Array[Partition] = {
    val partitions = ArrayBuffer.empty[SelectSolrRDDPartition]
    for {
      shard <- shards
      split <- SolrSupport.getShardSplits(query, shard, splitFieldName, splitsPerShard)
    } {
      // the partition index is the number of partitions created so far
      partitions += SelectSolrRDDPartition(partitions.size, "*", shard, split.query, split.replica)
    }
    partitions.toArray
  }

  // Workaround for SOLR-10490. TODO: Remove once fixed
  /** One export-handler partition per shard, each bound to a random replica. */
  def getExportHandlerPartitions(
      shards: List[SolrShard],
      query: SolrQuery): Array[Partition] =
    shards.iterator.zipWithIndex.map { case (shard, idx) =>
      // Chose any of the replicas as the active shard to query
      ExportHandlerPartition(idx, shard, query, SolrRDD.randomReplica(shard), 0, 0)
    }.toArray

  // Workaround for SOLR-10490. TODO: Remove once fixed
  /** One export-handler partition per split of every shard, numbered consecutively. */
  def getExportHandlerPartitions(
      shards: List[SolrShard],
      query: SolrQuery,
      splitFieldName: String,
      splitsPerShard: Int): Array[Partition] = {
    val partitions = ArrayBuffer.empty[ExportHandlerPartition]
    for {
      shard <- shards
      // Form a continuous iterator list so that we can pick different replicas for different partitions in round-robin mode
      split <- SolrSupport.getExportHandlerSplits(query, shard, splitFieldName, splitsPerShard)
    } {
      partitions += ExportHandlerPartition(partitions.size, shard, split.query, split.replica, split.numWorkers, split.workerId)
    }
    partitions.toArray
  }

}

/** A Solr shard, identified by `shardName`, together with its list of replicas. */
case class SolrShard(shardName: String, replicas: List[SolrReplica])

/** Describes a single replica of a Solr shard. */
case class SolrReplica(
    replicaNumber: Int,
    replicaName: String,
    replicaUrl: String,
    replicaHostName: String,
    locations: Array[InetAddress]) {

  /** The part of `replicaHostName` that precedes its first underscore. */
  def getHostAndPort(): String = {
    val firstUnderscore = replicaHostName.indexOf('_')
    replicaHostName.substring(0, firstUnderscore)
  }

  /** Human-readable summary with the locations joined by commas. */
  override def toString(): String = {
    val joinedLocations = locations.mkString(",")
    s"SolrReplica($replicaNumber) $replicaName: url=$replicaUrl, hostName=$replicaHostName, locations=$joinedLocations"
  }
}
Example 72
Source File: GranularBigVector.scala    From glint   with MIT License 5 votes vote down vote up
package glint.models.client.granular

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag

import glint.models.client.BigVector


  /**
    * Pushes `keys`/`values` to the underlying vector in consecutive slices of at most
    * `maximumMessageSize` entries.
    *
    * @return a future that is true only if every slice's push returned true
    */
  override def push(keys: Array[Long], values: Array[V])
    (implicit ec: ExecutionContext): Future[Boolean] = {
    val pending = new ArrayBuffer[Future[Boolean]](keys.length / maximumMessageSize)

    // start offsets 0, maximumMessageSize, 2 * maximumMessageSize, ... below keys.length
    Iterator.iterate(0)(_ + maximumMessageSize).takeWhile(_ < keys.length).foreach { start =>
      val stop = Math.min(keys.length, start + maximumMessageSize)
      pending += underlying.push(keys.slice(start, stop), values.slice(start, stop))
    }

    Future.sequence(pending.toIterator).transform(results => results.forall(ok => ok), err => err)
  }

} 
Example 73
Source File: GranularBigMatrix.scala    From glint   with MIT License 5 votes vote down vote up
package glint.models.client.granular

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag

import breeze.linalg.Vector
import glint.models.client.BigMatrix


  /**
    * Pulls the values at the given row/column pairs. Requests with more than
    * `maximumMessageSize` rows are issued in consecutive slices and the partial results are
    * concatenated in request order.
    */
  override def pull(rows: Array[Long],
                    cols: Array[Int])(implicit ec: ExecutionContext): Future[Array[V]] = {
    if (rows.length > maximumMessageSize) {
      val pending = new ArrayBuffer[Future[Array[V]]](rows.length / maximumMessageSize)

      // start offsets 0, maximumMessageSize, 2 * maximumMessageSize, ... below rows.length
      Iterator.iterate(0)(_ + maximumMessageSize).takeWhile(_ < rows.length).foreach { start =>
        val stop = Math.min(rows.length, start + maximumMessageSize)
        pending += underlying.pull(rows.slice(start, stop), cols.slice(start, stop))
      }

      Future.sequence(pending.toIterator).map { parts =>
        val merged = new ArrayBuffer[V](rows.length)
        parts.foreach(part => merged ++= part)
        merged.toArray
      }
    } else {
      // small enough for a single message
      underlying.pull(rows, cols)
    }
  }
} 
Example 74
Source File: HiveQLProcessBuilder.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.enginemanager.hive.process

import java.nio.file.Paths

import com.webank.wedatasphere.linkis.common.conf.Configuration
import com.webank.wedatasphere.linkis.enginemanager.conf.EnvConfiguration.{DEFAULT_JAVA_OPTS, JAVA_HOME, engineGCLogPath}
import com.webank.wedatasphere.linkis.enginemanager.hive.conf.HiveEngineConfiguration
import com.webank.wedatasphere.linkis.enginemanager.impl.UserEngineResource
import com.webank.wedatasphere.linkis.enginemanager.process.JavaProcessEngineBuilder
import com.webank.wedatasphere.linkis.enginemanager.{AbstractEngineCreator, EngineResource}
import com.webank.wedatasphere.linkis.protocol.engine.RequestEngine
import org.apache.commons.lang.StringUtils
import org.slf4j.LoggerFactory

import scala.collection.mutable.ArrayBuffer


  /** Runs `checkJarOrFile` on every classpath entry, in order. */
  override protected def classpathCheck(jarOrFiles: Array[String]): Unit =
    jarOrFiles.foreach(entry => checkJarOrFile(entry))
  // TODO: check a single jar/file entry of the classpath; currently does nothing.
  private def checkJarOrFile(jarOrFile:String):Unit = {

  }


  /**
    * Assembles the java command line used to launch the engine process, appending to
    * `commandLine` in order: the java executable, heap options, default and extra JVM options,
    * an optional debug agent (test mode only), the native library path, the classpath and the
    * main class.
    *
    * @param engineRequest resource description; must be a `UserEngineResource`
    * @param request       engine request whose `properties` drive the options; a client-memory
    *                      value lacking a "g" suffix is rewritten in place with the suffix added
    */
  override def build(engineRequest: EngineResource, request: RequestEngine): Unit = {
    this.request = request
    userEngineResource = engineRequest.asInstanceOf[UserEngineResource]
    val javaHome = JAVA_HOME.getValue(request.properties)
    if(StringUtils.isEmpty(javaHome)) {
      warn("We cannot find the java home, use java to run storage repl web server.")
      commandLine += "java"
    } else {
      commandLine += Paths.get(javaHome, "bin/java").toAbsolutePath.toFile.getAbsolutePath
    }
    if (request.properties.containsKey(HiveEngineConfiguration.HIVE_CLIENT_MEMORY.key)){
      val settingClientMemory = request.properties.get(HiveEngineConfiguration.HIVE_CLIENT_MEMORY.key)
      if (!settingClientMemory.toLowerCase().endsWith("g")){
        request.properties.put(HiveEngineConfiguration.HIVE_CLIENT_MEMORY.key, settingClientMemory + "g")
      }
      //request.properties.put(HiveEngineConfiguration.HIVE_CLIENT_MEMORY.key, request.properties.get(HiveEngineConfiguration.HIVE_CLIENT_MEMORY.key)+"g")
    }
    // heap size always ends up with a "g" suffix, lower-cased when the suffix was already present
    val clientMemory = HiveEngineConfiguration.HIVE_CLIENT_MEMORY.getValue(request.properties).toString
    val heapSize =
      if (clientMemory.toLowerCase().endsWith("g")) clientMemory.toLowerCase()
      else clientMemory + "g"
    commandLine += ("-Xmx" + heapSize)
    commandLine += ("-Xms" + heapSize)

    val javaOPTS = getExtractJavaOpts
    val alias = getAlias(request)
    if(StringUtils.isNotEmpty(DEFAULT_JAVA_OPTS.getValue))
      DEFAULT_JAVA_OPTS.getValue.format(engineGCLogPath(port, userEngineResource.getUser, alias)).split("\\s+").foreach(commandLine += _)
    if(StringUtils.isNotEmpty(javaOPTS)) javaOPTS.split("\\s+").foreach(commandLine += _)
    //engineLogJavaOpts(port, alias).trim.split(" ").foreach(commandLine += _)
    if(Configuration.IS_TEST_MODE.getValue) {
      // named debugPort so it does not shadow the engine `port` used above
      val debugPort = AbstractEngineCreator.getNewPort
      info(s"$toString open debug mode with port $debugPort.")
      commandLine += s"-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=$debugPort"
    }
    // "jars" may be absent from the request; Map.get then returns null, which used to cause a
    // NullPointerException on split. Treat a missing entry as "no extra jars".
    val extraJars = Option(request.properties.get("jars")).map(_.split(",")).getOrElse(Array.empty[String])
    val classpath = getClasspath(request.properties, getExtractClasspath) ++ extraJars
    classpathCheck(classpath)
    commandLine += "-Djava.library.path=/appcom/Install/hadoop/lib/native"
    commandLine += "-cp"
    commandLine += classpath.mkString(":")
    commandLine += "com.webank.wedatasphere.linkis.engine.DataWorkCloudEngineApplication"
  }


//  override def build(engineRequest: EngineResource, request: RequestEngine): Unit = {
//    import scala.collection.JavaConversions._
//    request.properties foreach {case (k, v) => LOG.info(s"request key is $k, value is $v")}
//    this.request = request
//    super.build(engineRequest, request)
//
//  }

  // Overrides the inherited flag with the constant `true`.
  // NOTE(review): presumably makes the parent builder add the Apache config path
  // to the engine's classpath - confirm against the base class.
  override protected val addApacheConfigPath: Boolean = true
} 
Example 75
Source File: JDBCSQLCodeParser.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.entrance.executer

import com.webank.wedatasphere.linkis.entrance.conf.JDBCConfiguration
import org.apache.commons.lang.StringUtils

import scala.collection.mutable.ArrayBuffer

/**
 * Splits a block of SQL text into individual statements and makes sure every
 * `select` statement carries a row limit.
 *
 * Note: statements are split on every `;`, including one that appears inside a
 * string literal.
 */
object JDBCSQLCodeParser {

  val separator = ";"
  val defaultLimit: Int = JDBCConfiguration.ENGINE_DEFAULT_LIMIT.getValue

  // A trailing "limit <n>" clause (already lower-cased), optionally followed by ';'.
  private val TrailingLimit = "limit\\s+(\\d+)\\s*;?".r

  /**
   * Splits `code` on [[separator]], drops blank statements and appends
   * `" limit " + defaultLimit` to every select statement without its own limit.
   *
   * @param code the SQL text; `null` or blank input yields an empty array
   * @return the statements in their original order
   * @throws IllegalArgumentException if a select statement asks for more rows than `defaultLimit`
   */
  def parse(code: String): Array[String] = {
    // StringUtils.split returns null for null input and drops empty tokens, so
    // text without any separator comes back as a single statement.
    val statements = Option(StringUtils.split(code, separator)).getOrElse(Array.empty[String])
    statements
      .filterNot(s => StringUtils.isBlank(s))
      .map(s => if (isSelectCmdNoLimit(s)) s + " limit " + defaultLimit else s)
  }

  /**
   * Tells whether `cmd` is a select statement that does not end in a `limit <n>` clause.
   *
   * The statement is trimmed before its first keyword is inspected, so leading
   * whitespace or newlines (typical for every statement after a `;`) no longer
   * hide the `select`. The limit keyword is matched case-insensitively.
   *
   * @param cmd a single SQL statement
   * @return `true` if a limit has to be appended, `false` for non-select statements
   *         and for selects that already end in a valid limit
   * @throws IllegalArgumentException if the statement's own limit exceeds `defaultLimit`
   */
  def isSelectCmdNoLimit(cmd: String): Boolean = {
    val statement = cmd.trim
    val isSelect = statement.split("\\s+").headOption.exists(_.equalsIgnoreCase("select"))
    if (!isSelect) false
    else {
      // Locale.ROOT keeps the lower-casing independent of the JVM's default locale.
      val lowered = statement.toLowerCase(java.util.Locale.ROOT)
      val limitIndex = lowered.lastIndexOf("limit")
      if (limitIndex < 0) true
      else lowered.substring(limitIndex).trim match {
        case TrailingLimit(digits) =>
          // BigInt avoids a NumberFormatException on absurdly large limits.
          if (BigInt(digits) > defaultLimit) throw new IllegalArgumentException("We at most allowed to limit " + defaultLimit + ", but your SQL has been over the max rows.")
          false
        // The last "limit" is not a trailing clause (e.g. part of a column name).
        case _ => true
      }
    }
  }


}
Example 76
Source File: PythonEngineExecutor.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.executors

import com.webank.wedatasphere.linkis.common.utils.Logging
import com.webank.wedatasphere.linkis.engine.PythonSession
import com.webank.wedatasphere.linkis.engine.exception.EngineException
import com.webank.wedatasphere.linkis.engine.execute.{EngineExecutor, EngineExecutorContext}
import com.webank.wedatasphere.linkis.engine.rs.RsOutputStream
import com.webank.wedatasphere.linkis.protocol.engine.JobProgressInfo
import com.webank.wedatasphere.linkis.resourcemanager.{LoadInstanceResource, Resource}
import com.webank.wedatasphere.linkis.rpc.Sender
import com.webank.wedatasphere.linkis.scheduler.executer._
import org.apache.commons.io.IOUtils

import scala.collection.mutable.ArrayBuffer


/**
 * Engine executor that hands every piece of code to a [[PythonSession]].
 *
 * @param outputPrintLimit forwarded unchanged to the `EngineExecutor` constructor
 */
class PythonEngineExecutor(outputPrintLimit: Int) extends EngineExecutor(outputPrintLimit, false) with SingleTaskOperateSupport with SingleTaskInfoSupport with Logging {
  override def getName: String = Sender.getThisServiceInstance.getInstance
  private val lineOutputStream = new RsOutputStream
  // Context of the most recent execution; null until the first call to executeLine.
  private[executors] var engineExecutorContext: EngineExecutorContext = _

  /** Reports the JVM heap currently in use (total minus free) with fixed values 2 and 1 for the other resource fields. */
  override def getActualUsedResources: Resource = {
    new LoadInstanceResource(Runtime.getRuntime.totalMemory() - Runtime.getRuntime.freeMemory(), 2, 1)
  }

  private val pySession = new PythonSession

  /**
   * Runs `code` in the python session. When a new context is passed, it is
   * remembered and handed to the session first.
   *
   * @return always `SuccessExecuteResponse()` when the session call returns normally
   */
  override protected def executeLine(engineExecutorContext: EngineExecutorContext, code: String): ExecuteResponse = {
    if(engineExecutorContext != this.engineExecutorContext){
      this.engineExecutorContext = engineExecutorContext
      pySession.setEngineExecutorContext(engineExecutorContext)
      //lineOutputStream.reset(engineExecutorContext)
      info("Python executor reset new engineExecutorContext!")
    }
    engineExecutorContext.appendStdout(s"$getName >> ${code.trim}")
    pySession.execute(code)
    //lineOutputStream.flush()
    SuccessExecuteResponse()
  }

  /** Prepends `completedLine` to `code` and executes the result as one unit. */
  override protected def executeCompletely(engineExecutorContext: EngineExecutorContext, code: String, completedLine: String): ExecuteResponse = {
    val newcode = completedLine + code
    info("newcode is " + newcode)
    executeLine(engineExecutorContext, newcode)
  }

  override def kill(): Boolean = true

  override def pause(): Boolean = true

  override def resume(): Boolean = true

  /**
   * Fraction of paragraphs processed so far.
   *
   * @return current / total paragraphs, or 0 when no context has been set yet or
   *         the total is not positive (avoids a NaN/Infinity result)
   */
  override def progress(): Float = {
    val context = this.engineExecutorContext
    if (context == null) 0.0f
    else {
      val total = context.getTotalParagraph
      if (total > 0) context.getCurrentParagraph / total.asInstanceOf[Float] else 0.0f
    }
  }

  /** This executor does not track per-job progress, so the result is always empty. */
  override def getProgressInfo: Array[JobProgressInfo] = Array.empty[JobProgressInfo]

  override def log(): String = ""

  /**
   * Closes the output stream quietly, then the python session.
   *
   * @throws EngineException (code 60004) if closing the session fails; the
   *                         original failure is logged before it is replaced
   */
  override def close(): Unit = {
    IOUtils.closeQuietly(lineOutputStream)
    try {
      pySession.close
    } catch {
      case e: Throwable =>
        // Record the real cause; the EngineException thrown below does not carry it.
        info("Failed to close the python session.", e)
        throw new EngineException(60004, "Engine shutdown exception(引擎关闭异常)")
    }
  }
}
Example 77
Source File: SparkPostExecutionHook.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.extension

import com.webank.wedatasphere.linkis.common.utils.Logging
import com.webank.wedatasphere.linkis.engine.execute.EngineExecutorContext
import com.webank.wedatasphere.linkis.scheduler.executer.ExecuteResponse

import scala.collection.mutable.ArrayBuffer


/**
 * Contract for a hook that can be registered with the `SparkPostExecutionHook`
 * companion object.
 * NOTE(review): presumably invoked after a piece of Spark code has been
 * executed - confirm against the caller.
 */
trait SparkPostExecutionHook {
  /** Name of this hook; it is written to the log when the hook is registered. */
  def hookName:String
  /**
   * Callback receiving the executor context, the execution response and the code.
   * NOTE(review): `code` is presumably the code that was just executed and
   * `executeResponse` its result - verify with the call site.
   */
  def callPostExecutionHook(engineExecutorContext: EngineExecutorContext, executeResponse: ExecuteResponse, code: String): Unit
}

/** Process-wide registry of post-execution hooks, kept in registration order. */
object SparkPostExecutionHook extends Logging{
  private val postHooks = ArrayBuffer[SparkPostExecutionHook]()

  /** Logs the hook's name and adds the hook to the end of the registry. */
  def register(postExecutionHook: SparkPostExecutionHook):Unit = {
    info(s"Get a postExecutionHook of ${postExecutionHook.hookName} register")
    postHooks += postExecutionHook
  }

  /** Returns a new array holding every hook registered so far. */
  def getSparkPostExecutionHooks():Array[SparkPostExecutionHook] = postHooks.toArray
}
Example 78
Source File: SparkPreExecutionHook.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.extension

import com.webank.wedatasphere.linkis.common.utils.Logging
import com.webank.wedatasphere.linkis.engine.execute.EngineExecutorContext

import scala.collection.mutable.ArrayBuffer


/**
 * Contract for a hook that can be registered with the `SparkPreExecutionHook`
 * companion object.
 * NOTE(review): presumably invoked before a piece of Spark code is executed -
 * confirm against the caller.
 */
trait SparkPreExecutionHook {
  /** Name of this hook; it is written to the log when the hook is registered. */
  def hookName:String
  /**
   * Callback receiving the executor context and the code; returns a string.
   * NOTE(review): the return value is presumably the (possibly rewritten) code
   * to execute - verify with the call site.
   */
  def callPreExecutionHook(engineExecutorContext: EngineExecutorContext, code: String): String
}

/** Process-wide registry of pre-execution hooks, kept in registration order. */
object SparkPreExecutionHook extends Logging{
  private val preHooks = ArrayBuffer[SparkPreExecutionHook]()

  /** Logs the hook's name and adds the hook to the end of the registry. */
  def register(preExecutionHook: SparkPreExecutionHook):Unit = {
    info(s"Get a preExecutionHook of ${preExecutionHook.hookName} register")
    preHooks += preExecutionHook
  }

  /** Returns a new array holding every hook registered so far. */
  def getSparkPreExecutionHooks():Array[SparkPreExecutionHook] = preHooks.toArray
}
Example 79
Source File: SparkSqlExtension.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.extension

import java.util.concurrent._

import com.webank.wedatasphere.linkis.common.conf.CommonVars
import com.webank.wedatasphere.linkis.common.utils.{Logging, Utils}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.{DataFrame, SQLContext}

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._

/**
 * Base class for extensions that run a rule against a finished SQL execution.
 * Each instance runs its rule on its own pool of daemon threads.
 */
abstract class SparkSqlExtension extends Logging{

  private val maxPoolSize = CommonVars("wds.linkis.dws.ujes.spark.extension.max.pool",5).getValue

  // With an unbounded work queue a ThreadPoolExecutor never starts more than its
  // core threads, so the configured maximum must be used as the core size for it
  // to have any effect. Idle threads are released after 2 seconds.
  private val executor = {
    val poolSize = math.max(1, maxPoolSize)
    val pool = new ThreadPoolExecutor(poolSize, poolSize, 2, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable](), new ThreadFactory {
      override def newThread(r: Runnable): Thread = {
        val thread = new Thread(r)
        thread.setDaemon(true)
        thread
      }
    })
    pool.allowCoreThreadTimeOut(true)
    pool
  }

  /**
   * Submits [[extensionRule]] to the pool and waits for it via `Utils.waitUntil`
   * with `timeout` milliseconds. Any failure is logged and swallowed, so this
   * method never throws.
   * NOTE(review): when the wait fails (presumably on timeout) the submitted task
   * is not cancelled and keeps running in the background.
   *
   * @param sqlContext   passed through to the rule
   * @param command      passed through to the rule
   * @param dataFrame    its `queryExecution` is passed to the rule
   * @param timeout      wait limit in milliseconds
   * @param sqlStartTime passed through to the rule
   */
  final def afterExecutingSQL(sqlContext: SQLContext,command: String,dataFrame: DataFrame,timeout:Long,sqlStartTime:Long):Unit = {
    try {
      val task = new Runnable {
        override def run(): Unit = extensionRule(sqlContext,command,dataFrame.queryExecution,sqlStartTime)
      }
      val future = executor.submit(task)
      Utils.waitUntil(future.isDone, timeout.milliseconds)
    } catch {
      // Deliberately best-effort: an extension must never break the SQL execution itself.
      case e: Throwable => info("Failed to execute SparkSqlExtension: ", e)
    }
  }

  /** The rule to run; called on a pool thread by [[afterExecutingSQL]]. */
  protected def extensionRule(sqlContext: SQLContext,command: String,queryExecution: QueryExecution,sqlStartTime:Long):Unit


}

/** Process-wide registry of SQL extensions, kept in registration order. */
object SparkSqlExtension extends Logging {

  private val extensions = ArrayBuffer[SparkSqlExtension]()

  /** Logs the registration and adds the extension to the end of the registry. */
  def register(sqlExtension: SparkSqlExtension):Unit = {
    info("Get a sqlExtension register")
    extensions += sqlExtension
  }

  /** Returns a new array holding every extension registered so far. */
  def getSparkSqlExtensions():Array[SparkSqlExtension] = extensions.toArray
}
Example 80
Source File: CSTableParser.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.cs

import java.util.regex.Pattern

import com.webank.wedatasphere.linkis.common.utils.Logging
import com.webank.wedatasphere.linkis.cs.client.service.CSTableService
import com.webank.wedatasphere.linkis.cs.common.entity.metadata.CSTable
import com.webank.wedatasphere.linkis.cs.common.utils.CSCommonUtils
import com.webank.wedatasphere.linkis.engine.exception.ExecuteError
import com.webank.wedatasphere.linkis.engine.execute.EngineExecutorContext
import org.apache.commons.lang.StringUtils
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.csv.DolphinToSpark

import scala.collection.mutable.ArrayBuffer


  /**
   * Asks the `CSTableService` singleton for the upstream table matching
   * `csTempTable` for the given context id and node name.
   */
  def getCSTable(csTempTable:String,  contextIDValueStr: String, nodeNameStr: String):CSTable = {
    val tableService = CSTableService.getInstance()
    tableService.getUpstreamSuitableTable(contextIDValueStr, nodeNameStr, csTempTable)
  }

  /**
   * Creates a temp view named after `csTable` from the table's location in the
   * Hive-enabled Spark session, logging before and after.
   */
  def registerTempTable(csTable: CSTable):Unit = {
    val session = SparkSession.builder().enableHiveSupport().getOrCreate()
    // Shared tail of both log messages.
    def viewDescription = s"viewName(${csTable.getName}) location(${csTable.getLocation})"
    info(s"Start to create  tempView to sparkSession $viewDescription")
    DolphinToSpark.createTempView(session, csTable.getName, csTable.getLocation, true)
    info(s"Finished to create  tempView to sparkSession $viewDescription")
  }
} 
Example 81
Source File: LogContainer.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.spark.common

import scala.collection.Iterable
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer


/**
 * Fixed-size ring buffer of log lines.
 *
 * `flag` is the index of the oldest unread entry and `tail` the index of the
 * next write slot, so the buffer holds at most `logSize - 1` entries; once it is
 * full the oldest entry is dropped for every new one. All operations lock on
 * the same monitor (the backing array).
 *
 * @param logSize number of slots in the backing array
 */
class LogContainer(val logSize: Int) {

  private final val logs = new Array[String](logSize)
  private var flag, tail = 0

  /** Appends one line, dropping the oldest unread line when the buffer is full. */
  def putLog(log: String): Unit = this.logs.synchronized {
    val next = (tail + 1) % logSize
    // Advancing tail would collide with flag: give up the oldest entry.
    if (next == flag) flag = (flag + 1) % logSize
    logs(tail) = log
    tail = next
  }

  /** Appends all lines in iteration order as one atomic step. */
  def putLogs(logs: Iterable[String]): Unit = this.logs.synchronized {
    // The parameter shadows the field, hence the explicit `this.logs` for the lock.
    logs.foreach(putLog)
  }

  /** Discards every unread line. */
  def reset(): Unit = this.logs.synchronized {
    flag = 0
    tail = 0
  }

  /**
   * Returns the unread lines, oldest first, and marks them as read.
   *
   * @return the drained lines; empty when nothing is pending
   */
  def getLogs: List[String] = this.logs.synchronized {
    if (flag == tail) List.empty[String]
    else {
      // Unwrap the ring so the range below is contiguous.
      val end = if (flag > tail) tail + logSize else tail
      val drained = (flag until end).map(i => logs(i % logSize)).toList
      flag = tail
      drained
    }
  }

  /** Number of unread lines. */
  def size: Int = this.logs.synchronized {
    if (flag == tail) 0
    else if (flag > tail) tail + logSize - flag
    else tail - flag
  }

  /** Same as [[getLogs]], as a read-only Java list. */
  def getLogList: java.util.List[String] =
    java.util.Collections.unmodifiableList(java.util.Arrays.asList[String](getLogs: _*))

}
Example 82
Source File: SparkConfiguration.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.enginemanager.configuration

import com.webank.wedatasphere.linkis.common.conf.{CommonVars, Configuration}
import com.webank.wedatasphere.linkis.common.utils.{ClassUtils, Logging}
import com.webank.wedatasphere.linkis.engine.factory.SparkEngineExecutorFactory
import com.webank.wedatasphere.linkis.enginemanager.AbstractEngineCreator

import scala.collection.mutable.ArrayBuffer

/**
 * Configuration entries for the Spark engine manager. Each entry binds a
 * property key to a default value through `CommonVars`.
 */
object SparkConfiguration extends Logging {
  val SPARK_MAX_PARALLELISM_USERS = CommonVars[Int]("wds.linkis.engine.spark.user.parallelism", 100)
  val SPARK_USER_MAX_WAITING_SIZE = CommonVars[Int]("wds.linkis.engine.spark.user.waiting.max", 100)

  val SPARK_SESSION_HOOK = CommonVars[String]("wds.linkis.engine.spark.session.hook", "")
  val SPARK_LANGUAGE_REPL_INIT_TIME = CommonVars[String]("wds.linkis.engine.spark.language-repl.init.time", new String("30s"))

  val SPARK_ALLOW_REQUEST_ALL_YARN_MEMORY = CommonVars[String]("wds.linkis.engine.spark.allow.all-memory.when.queue", new String("60g"))
  val SPARK_ALLOW_REQUEST_ALL_YARN_CORES = CommonVars[Int]("wds.linkis.engine.spark.allow.all-cores.when.queue", 30)
  val SPARK_USER_MAX_ALLOCATE_SESSIONS = CommonVars[Int]("wds.linkis.engine.spark.user.sessions.max", 5)
  val SPARK_USER_MAX_ALLOCATE_YARN_MEMORY = CommonVars[String]("wds.linkis.engine.spark.user.yarn.memory.max", new String("100g"))
  val SPARK_USER_MAX_ALLOCATE_YARN_CORES = CommonVars[Int]("wds.linkis.engine.spark.user.cores.max", 50)
  val SPARK_USER_MAX_ALLOCATE_DRIVER_MEMORY = CommonVars[String]("wds.linkis.engine.spark.user.driver.memory.max", new String("15g"))
  // The driver-core limit reuses the per-user session limit entry.
  val SPARK_USER_MAX_ALLOCATE_DRIVER_CORES = SPARK_USER_MAX_ALLOCATE_SESSIONS
  val SPARK_USER_MAX_RESOURCE_IN_QUEUE = CommonVars[Float]("wds.linkis.engine.spark.user.queue.resources.max", 0.6f)
  val SPARK_DANGER_QUEUE_USED_CAPACITY = CommonVars[Float]("wds.linkis.engine.spark.danger.queue.used", 0.2f)
  val SPARK_DANGER_QUEUE_USER_ALLOCATE_SESSION = CommonVars[Int]("wds.linkis.engine.spark.danger.user.sessions.max", 2)
  val SPARK_WARN_QUEUE_USED_CAPACITY = CommonVars[Float]("wds.linkis.engine.spark.warning.queue.used", 0.5f)
  val SPARK_WARN_QUEUE_USER_ALLOCATE_SESSION = CommonVars[Int]("wds.linkis.engine.spark.warning.user.sessions.max", 3)

  val PROXY_USER = CommonVars[String]("spark.proxy.user", "${UM}")
  val SPARK_CLIENT_MODE = "client"
  val SPARK_CLUSTER_MODE = "cluster"
  val SPARK_DEPLOY_MODE = CommonVars[String]("spark.submit.deployMode", SPARK_CLIENT_MODE)
  val SPARK_APPLICATION_JARS = CommonVars[String]("spark.application.jars", "", "User-defined jars, separated by English, must be uploaded to HDFS first, and must be full path to HDFS.(用户自定义jar包,多个以英文,隔开,必须先上传到HDFS,且需为HDFS全路径。)")

  val SPARK_EXTRA_JARS = CommonVars[String]("spark.jars", "", "Additional jar package, Driver and Executor take effect(额外的jar包,Driver和Executor生效)")

  val MAPRED_OUTPUT_COMPRESS = CommonVars[String]("mapred.output.compress", "true", "Whether the map output is compressed(map输出结果是否压缩)")
  val MAPRED_OUTPUT_COMPRESSION_CODEC = CommonVars[String]("mapred.output.compression.codec", "org.apache.hadoop.io.compress.GzipCodec", "Map output compression method(map输出结果压缩方式)")
  val SPARK_MASTER = CommonVars[String]("spark.master", "yarn", "Default master(默认master)")
  val SPARK_OUTPUTDIR = CommonVars[String]("spark.outputDir", "/home/georgeqiao", "Default output path(默认输出路径)")

  val DWC_SPARK_USEHIVECONTEXT = CommonVars[Boolean]("wds.linkis.spark.useHiveContext", true)
  val ENGINE_JAR = CommonVars[String]("wds.linkis.enginemanager.core.jar", ClassUtils.jarOfClass(classOf[SparkEngineExecutorFactory]).head)
  val SPARK_DRIVER_CLASSPATH = CommonVars[String]("wds.linkis.spark.driver.conf.mainjar", "")
  // The default embeds the debug agent option when test mode is on (see getJavaRemotePort).
  val SPARK_DRIVER_EXTRA_JAVA_OPTIONS = CommonVars[String]("spark.driver.extraJavaOptions", "\"-Dwds.linkis.configuration=linkis-engine.properties " + getJavaRemotePort + "\"")
  // "%s" is a format placeholder inside the -Xloggc option.
  val DEFAULT_JAVA_OPTS = CommonVars[String]("wds.linkis.engine.javaOpts.default", "-server -XX:+UseG1GC -XX:MaxPermSize=250m -XX:PermSize=128m " +
    "-Xloggc:%s -XX:+PrintGCDetails -XX:+PrintGCTimeStamps -XX:+PrintGCDateStamps -Dwds.linkis.configuration=linkis-engine.properties")
  val SPARK_ML_BUCKET_FIELDS = CommonVars[String]("wds.linkis.engine.spark.ml.bucketFields", "age[0,18,30,60,100]")

  val SPARK_SUBMIT_CMD = CommonVars[String]("wds.linkis.engine.spark.submit.cmd", "spark-submit")
  // Ports already handed out by getAvailablePort; created lazily on first use.
  private var Ports: ArrayBuffer[Int] = _

  /**
   * Builds the JDWP agent option for remote debugging.
   *
   * @return the `-agentlib:jdwp=...` option with a random port in 1024..65535
   *         when test mode is enabled, otherwise the empty string
   */
  def getJavaRemotePort: String = {
    if (Configuration.IS_TEST_MODE.getValue) {
      val r = new scala.util.Random()
      // nextInt's bound is exclusive, so this covers 1024..65535 and can no
      // longer produce the invalid port 65536.
      val port = 1024 + r.nextInt(65536 - 1024)
      info(s"open debug mode with port $port.")
      s"-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=$port"
    } else {
      ""
    }
  }

  /** Picks a fresh port from `AbstractEngineCreator` that this object has not handed out before. */
  private def getAvailablePort: Int = synchronized {
    var port = AbstractEngineCreator.getNewPort
    info("Get new port " + port)
    if (Ports == null) {
      info("Get inInitPorts is null ")
      Ports = ArrayBuffer(0, 1)
      info("Current ports is " + Ports.toList.toString())
    }
    // Retry until the candidate has not been handed out yet.
    while (Ports.contains(port)) {
      port = AbstractEngineCreator.getNewPort
    }
    Ports += port
    port
  }
}
Example 83
Source File: CSResourceParser.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.cs


import java.util
import java.util.regex.Pattern

import com.webank.wedatasphere.linkis.cs.client.service.CSResourceService
import com.webank.wedatasphere.linkis.engine.PropertiesExecuteRequest
import org.apache.commons.lang.StringUtils

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer


/**
 * Finds `cs://` resource references in code, resolves them against the
 * upstream BML resources and rewrites the code to use the plain file names.
 */
class CSResourceParser {

  private val pb = Pattern.compile("cs://[^\\s\"]+[$\\s]{0,1}", Pattern.CASE_INSENSITIVE)

  private val PREFIX = "cs://"

  /** Returns every match of the `cs://` pattern in `code`, trimmed, in order of appearance. */
  private def getPreFixResourceNames(code: String): Array[String] = {
    val matcher = pb.matcher(code)
    Iterator.continually(matcher).takeWhile(_.find()).map(_.group.trim).toArray
  }

  /**
   * Resolves the `cs://` references in `code`.
   *
   * For every reference whose name equals the downloaded file name of an
   * upstream resource, a map with `resourceId`, `version` and `fileName` is
   * collected; the list of these maps is stored under the `resources` key of
   * the request's properties. References without a matching resource are left
   * untouched.
   *
   * @return `code` with every resolved reference replaced by its bare file name
   */
  def parse(executeRequest: PropertiesExecuteRequest, code: String, contextIDValueStr: String, nodeNameStr: String): String = {

    //TODO getBMLResource peaceWong
    val bmlResourceList = CSResourceService.getInstance().getUpstreamBMLResource(contextIDValueStr, nodeNameStr)

    val parsedResources = new util.ArrayList[util.Map[String, Object]]()
    // Parallel buffers: the i-th prefixed name is replaced by the i-th parsed name.
    val preFixNames = new ArrayBuffer[String]()
    val parsedNames = new ArrayBuffer[String]()

    for (preFixResourceName <- getPreFixResourceNames(code)) {
      val resourceName = preFixResourceName.replace(PREFIX, "").trim
      bmlResourceList.find(_.getDownloadedFileName.equals(resourceName)).foreach { bmlResource =>
        val resourceInfo = new util.HashMap[String, Object]()
        resourceInfo.put("resourceId", bmlResource.getResourceId)
        resourceInfo.put("version", bmlResource.getVersion)
        resourceInfo.put("fileName", resourceName)
        parsedResources.add(resourceInfo)
        preFixNames += preFixResourceName
        parsedNames += resourceName
      }
    }

    executeRequest.properties.put("resources", parsedResources)
    StringUtils.replaceEach(code, preFixNames.toArray, parsedNames.toArray)
  }

}
Example 84
Source File: RsOutputStream.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.rs

import java.io.OutputStream

import com.webank.wedatasphere.linkis.common.io.resultset.ResultSetWriter
import com.webank.wedatasphere.linkis.common.io.{MetaData, Record}
import com.webank.wedatasphere.linkis.common.utils.Logging
import com.webank.wedatasphere.linkis.engine.execute.EngineExecutorContext
import com.webank.wedatasphere.linkis.storage.LineRecord

import scala.collection.mutable.ArrayBuffer



/**
 * Output stream that collects bytes into lines and writes each completed line
 * as a `LineRecord` to a result-set writer.
 *
 * Bytes are ignored until [[ready]] has been called, and a writer must have been
 * installed through [[reset]]. Every operation touching the line buffer or the
 * writer holds this object's monitor.
 */
class RsOutputStream extends OutputStream with Logging{
  private val line = ArrayBuffer[Byte]()
  // Read outside the lock in write(), hence volatile.
  @volatile private var isReady = false
  private var writer: ResultSetWriter[_ <: MetaData, _ <: Record] = _

  /** Decodes the buffered bytes as UTF-8, writes them as one record and clears the buffer. Caller holds the lock. */
  private def emitLine(): Unit = {
    val outStr = new String(line.toArray,"UTF-8")
    writer.addRecord(new LineRecord(outStr))
    //info("output line:" + outStr)
    line.clear()
  }

  /**
   * Buffers one byte; a newline completes the current line (even an empty one)
   * and writes it. Does nothing before [[ready]]; warns when no writer is set.
   */
  override def write(b: Int) = if(isReady) synchronized {
    if(writer != null) {
      if (b == '\n') emitLine()
      else line += b.toByte
    }else{
       warn("writer is null")
    }
  }

  /** Installs a fresh default writer from the context and adds `null` metadata to it. */
  def reset(engineExecutorContext: EngineExecutorContext) = synchronized {
    writer = engineExecutorContext.createDefaultResultSetWriter()
    writer.addMetaData(null)
  }

  /** Enables [[write]]; before this call written bytes are dropped. */
  def ready() = isReady = true

  /** Writes a pending partial line, if any, as a record. */
  override def flush(): Unit = synchronized {
    if(writer != null && line.nonEmpty) emitLine()
  }

  /** The writer's string form, or `null` when no writer is installed. */
  override def toString = if(writer != null) writer.toString() else null

  /** Flushes, closes and forgets the writer; a no-op when none is installed. */
  override def close() = synchronized {
    if(writer != null) {
      flush()
      writer.close()
      writer = null
    }
  }
}
Example 85
Source File: CodeGeneratorEngineHook.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.execute.hook

import java.io.File

import com.webank.wedatasphere.linkis.common.utils.Logging
import com.webank.wedatasphere.linkis.engine.execute.{EngineExecutor, EngineHook}
import com.webank.wedatasphere.linkis.scheduler.executer.{ExecuteRequest, RunTypeExecuteRequest}
import com.webank.wedatasphere.linkis.server.JMap
import org.apache.commons.io.FileUtils
import org.apache.commons.lang.StringUtils

import scala.collection.mutable.ArrayBuffer


@Deprecated
//changed to UdfLoadEngineHook
/**
 * Engine hook that reads the files listed under the `udf.paths` parameter and,
 * once the engine exists, submits the sections accepted by the concrete
 * subclass to it.
 */
abstract class CodeGeneratorEngineHook extends EngineHook with Logging{ self =>
  val udfPathProp = "udf.paths"
  protected var creator: String = _
  protected var user: String = _
  // Concatenated contents of all udf files; set by beforeCreateEngine.
  protected var initSpecialCode: String = _
  protected val runType: String

  /** Decides whether a section header line (a line starting with `%`) belongs to this hook. */
  protected def acceptCodeType(line: String): Boolean

  /**
   * Splits `initSpecialCode` into code snippets.
   *
   * Lines starting with `%` are section headers: an accepted header closes the
   * snippet collected so far (possibly an empty one) and starts a new one, a
   * rejected header makes the following lines be skipped. Empty lines are
   * ignored; lines before the first header are collected.
   *
   * @return the snippets in order, each with its lines joined by `\n`
   */
  protected def generateCode(): Array[String] = {
    val codeBuffer = new ArrayBuffer[String]
    val statementBuffer = new ArrayBuffer[String]
    var accept = true
    initSpecialCode.split("\n").foreach{
      case "" =>
      case l if l.startsWith("%") =>
        if(acceptCodeType(l)){
          accept = true
          codeBuffer.append(statementBuffer.mkString("\n"))
          statementBuffer.clear()
        }else{
          accept = false
        }
      case l if accept => statementBuffer.append(l)
      case _ =>
    }
    if(statementBuffer.nonEmpty) codeBuffer.append(statementBuffer.mkString("\n"))
    codeBuffer.toArray
  }

  /**
   * Remembers `creator` and `user` and loads every file listed, comma-separated,
   * under `udf.paths`. A missing `udf.paths` entry yields empty code instead of
   * a NullPointerException.
   *
   * @return `params`, unchanged
   */
  override def beforeCreateEngine(params: JMap[String, String]): JMap[String, String] = {
    creator = params.get("creator")
    user = params.get("user")
    // StringUtils.split returns null for a null input, i.e. when the key is absent.
    val udfPaths = Option(StringUtils.split(params.get(udfPathProp), ",")).getOrElse(Array.empty[String])
    initSpecialCode = udfPaths.map(readFile).mkString("\n")
    params
  }

  /** Submits every non-empty generated snippet to the executor using this hook's run type. */
  override def afterCreatedEngine(executor: EngineExecutor): Unit = {
    generateCode().foreach {
      case "" =>
      case c: String =>
        info("Submit udf registration to engine, code: " + c)
        executor.execute(new ExecuteRequest with RunTypeExecuteRequest{
          override val code: String = c
          override val runType: String = self.runType
        })
        info("executed code: " + c)
    }
  }

  /**
   * Reads a whole file as a string.
   *
   * @return the file's content, or the empty string when the file does not exist
   */
  protected def readFile(path: String): String = {
    info("read file: " + path)
    val file = new File(path)
    if(file.exists()){
      FileUtils.readFileToString(file)
    } else {
      info("udf file: [" + path + "] doesn't exist, ignore it.")
      ""
    }
  }
}
@Deprecated
class SqlCodeGeneratorEngineHook extends CodeGeneratorEngineHook{
  // Run type attached to every snippet this hook submits.
  override val runType = "sql"
  // Accepts sections whose header line begins with "%sql".
  override protected def acceptCodeType(line: String): Boolean = line.startsWith("%sql")
}
@Deprecated
class PythonCodeGeneratorEngineHook extends CodeGeneratorEngineHook{
  // Run type attached to every snippet this hook submits.
  override val runType = "python"
  // Accepts sections whose header line begins with "%python".
  override protected def acceptCodeType(line: String): Boolean = line.startsWith("%python")
}
@Deprecated
class ScalaCodeGeneratorEngineHook extends CodeGeneratorEngineHook{
  // Run type attached to every snippet this hook submits.
  override val runType = "scala"
  // Accepts sections whose header line begins with "%scala".
  override protected def acceptCodeType(line: String): Boolean = line.startsWith("%scala")
}
Example 86
Source File: AbstractEngineCreator.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.enginemanager

import java.net.ServerSocket

import com.webank.wedatasphere.linkis.common.conf.DWCArgumentsParser
import com.webank.wedatasphere.linkis.common.utils.Utils
import com.webank.wedatasphere.linkis.enginemanager.conf.EngineManagerConfiguration
import com.webank.wedatasphere.linkis.enginemanager.exception.EngineManagerErrorException
import com.webank.wedatasphere.linkis.enginemanager.impl.UserTimeoutEngineResource
import com.webank.wedatasphere.linkis.enginemanager.process.{CommonProcessEngine, ProcessEngine, ProcessEngineBuilder}
import com.webank.wedatasphere.linkis.protocol.engine.{EngineCallback, RequestEngine}
import com.webank.wedatasphere.linkis.rpc.Sender
import com.webank.wedatasphere.linkis.server.{JMap, toScalaMap}
import org.apache.commons.io.IOUtils

import scala.collection.mutable.ArrayBuffer


/**
 * Base engine creator: reserves a port, builds the engine process description
 * and assembles the Spring and DWC start-up parameters.
 */
abstract class AbstractEngineCreator extends EngineCreator {

  // Ports reserved for engines that are being started; guarded by this object's monitor.
  private val inInitPorts = ArrayBuffer[Int]()

  /** Reserves a free port that is not already reserved by this creator. */
  private def getAvailablePort: Int = synchronized {
    var port = AbstractEngineCreator.getNewPort
    while(inInitPorts.contains(port)) port = AbstractEngineCreator.getNewPort
    inInitPorts += port
    port
  }

  /** Releases a reserved port; uses the same lock as the reservation. */
  def removePort(port: Int): Unit = synchronized {
    inInitPorts -= port
  }

  protected def createProcessEngineBuilder(): ProcessEngineBuilder

  /** Collects the request properties whose key starts with `spring.`, with that prefix removed. */
  protected def getExtractSpringConfigs(requestEngine: RequestEngine): JMap[String, String] = {
    val springConf = new JMap[String, String]
    // 7 == "spring.".length
    requestEngine.properties.keysIterator.filter(_.startsWith("spring.")).foreach(key => springConf.put(key.substring(7), requestEngine.properties.get(key)))
    springConf
  }

  /** Creates the process engine, passing the timeout along when the resource carries one. */
  protected def createEngine(processEngineBuilder:ProcessEngineBuilder,parser:DWCArgumentsParser):ProcessEngine={
     processEngineBuilder.getEngineResource match {
      case timeout: UserTimeoutEngineResource =>
        new CommonProcessEngine(processEngineBuilder, parser, timeout.getTimeout)
      case _ =>
        new CommonProcessEngine(processEngineBuilder, parser)
    }
  }

  /**
   * Creates an engine for the request on a freshly reserved port.
   *
   * If anything fails, the reserved port is released again before the failure is
   * rethrown; the caller never learns the port and could not release it itself.
   *
   * @throws EngineManagerErrorException (code 30000) if a request property key or value contains a space
   */
  override def create(ticketId: String, engineRequest: EngineResource, request: RequestEngine): Engine = {
    val port = getAvailablePort
    try {
      val processEngineBuilder = createProcessEngineBuilder()
      processEngineBuilder.setPort(port)
      processEngineBuilder.build(engineRequest, request)
      val parser = new DWCArgumentsParser
      var springConf = Map("spring.application.name" -> EngineManagerConfiguration.ENGINE_SPRING_APPLICATION_NAME.getValue,
        "server.port" -> port.toString, "spring.profiles.active" -> "engine",
        "logging.config" -> "classpath:log4j2-engine.xml",
        "eureka.client.serviceUrl.defaultZone" -> EngineManagerReceiver.getSpringConf("eureka.client.serviceUrl.defaultZone"))
      springConf = springConf ++: getExtractSpringConfigs(request).toMap
      parser.setSpringConf(springConf)
      var dwcConf = Map("ticketId" -> ticketId, "creator" -> request.creator, "user" -> request.user) ++:
        EngineCallback.callbackToMap(EngineCallback(Sender.getThisServiceInstance.getApplicationName, Sender.getThisServiceInstance.getInstance))
      if(request.properties.exists{case (k, v) => k.contains(" ") || (v != null && v.contains(" "))})
        throw new EngineManagerErrorException(30000, "Startup parameters contain spaces!(启动参数中包含空格!)")
      dwcConf = dwcConf ++: request.properties.toMap
      parser.setDWCConf(dwcConf)
      val engine = createEngine(processEngineBuilder,parser)
      engine.setTicketId(ticketId)
      engine.setPort(port)
      engine match {
        case commonEngine: CommonProcessEngine => commonEngine.setUser(request.user)
        case _ =>
      }
      engine
    } catch {
      case t: Throwable =>
        // Clean up and rethrow unchanged.
        removePort(port)
        throw t
    }
  }
}
object AbstractEngineCreator {
  /**
   * Opens a server socket on port 0, reads the port that was assigned to it and
   * closes the socket quietly again.
   */
  private[enginemanager] def getNewPort: Int = {
    val probe = new ServerSocket(0)
    Utils.tryFinally {
      probe.getLocalPort
    } {
      IOUtils.closeQuietly(probe)
    }
  }
}
Example 87
Source File: ScalaDDLCreator.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.metadata.ddl

import com.webank.wedatasphere.linkis.common.utils.Logging
import com.webank.wedatasphere.linkis.metadata.conf.MdqConfiguration
import com.webank.wedatasphere.linkis.metadata.domain.mdq.bo.{MdqTableBO, MdqTableFieldsInfoBO}
import com.webank.wedatasphere.linkis.metadata.exception.MdqIllegalParamException
import org.apache.commons.lang.StringUtils

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer

/** Generates the Scala/Spark-SQL snippet that creates a table from its metadata description. */
object ScalaDDLCreator extends DDLCreator with SQLConst with Logging{



  /**
   * Builds the create-table code for `tableInfo`.
   *
   * Fields flagged as partition fields go into the partition clause, all others
   * into the column list (with a comment clause when the field has a comment).
   *
   * @param tableInfo table metadata: database, name, fields and partition flag
   * @param user      only used in the log message
   * @return the generated code
   * @throws MdqIllegalParamException if a partition field has an empty name or type
   */
  override def createDDL(tableInfo:MdqTableBO, user:String): String = {
    logger.info(s"begin to generate ddl for user $user using ScalaDDLCreator")
    val dbName = tableInfo.getTableBaseInfo.getBase.getDatabase
    val tableName = tableInfo.getTableBaseInfo.getBase.getName
    val fields = tableInfo.getTableFieldsInfo
    val createTableCode = new StringBuilder
    createTableCode.append(SPARK_SQL).append(LEFT_PARENTHESES).append(MARKS).append(CREATE_TABLE)
    createTableCode.append(dbName).append(".").append(tableName)
    createTableCode.append(LEFT_PARENTHESES)
    val partitions = new ArrayBuffer[MdqTableFieldsInfoBO]()
    val fieldsArray = new ArrayBuffer[String]()
    fields foreach { field =>
      if (field.getPartitionField != null && field.getPartitionField == true) partitions += field
      else {
        val column = field.getName + SPACE + field.getType
        val desc = field.getComment
        fieldsArray += (
          if (StringUtils.isNotEmpty(desc)) column + SPACE + COMMENT + SPACE + SINGLE_MARK + desc + SINGLE_MARK
          else column
        )
      }
    }
    createTableCode.append(fieldsArray.mkString(COMMA)).append(RIGHT_PARENTHESES).append(SPACE)
    if (partitions.nonEmpty){
      val partitionArr = new ArrayBuffer[String]()
      partitions foreach { p =>
        val name = p.getName
        val _type = p.getType
        if (StringUtils.isEmpty(name) || StringUtils.isEmpty(_type)) throw MdqIllegalParamException("partition name or type is null")
        partitionArr += (name + SPACE + _type)
      }
      createTableCode.append(PARTITIONED_BY).append(LEFT_PARENTHESES).append(partitionArr.mkString(COMMA)).
        append(RIGHT_PARENTHESES).append(SPACE)
    }
    // A partitioned table without any partition field is partitioned by the
    // configured default partition name (the original comment mentions `ds`).
    if(partitions.isEmpty && tableInfo.getTableBaseInfo.getBase.getPartitionTable){
      val partition = MdqConfiguration.DEFAULT_PARTITION_NAME.getValue
      val _type = "string"
      createTableCode.append(PARTITIONED_BY).append(LEFT_PARENTHESES).append(partition).append(SPACE).append(_type).
        append(RIGHT_PARENTHESES).append(SPACE)
    }
    createTableCode.append(STORED_AS).append(SPACE).append(MdqConfiguration.DEFAULT_STORED_TYPE.getValue).append(SPACE)
    createTableCode.append(MARKS)
    createTableCode.append(RIGHT_PARENTHESES)
    val finalCode = createTableCode.toString()
    logger.info(s"End to create ddl code, code is $finalCode")
    finalCode
  }

  /**
   * Manual smoke test: prints the content of a JSON file.
   *
   * @param args optional; the first argument is the file to read, defaulting to
   *             the original hard-coded path
   */
  def main(args: Array[String]): Unit = {
    val filePath = args.headOption.getOrElse("E:\\data\\json\\data.json")
    val source = scala.io.Source.fromFile(filePath)
    // Close the file handle even when reading fails.
    val json = try source.mkString finally source.close()
    println(json)

   // val obj = new Gson().fromJson(json, classOf[MdqTableVO])
    //val sql = createDDL(obj, "hadoop")
    //println(System.currentTimeMillis())
    //println(sql)
  }


}
Example 88
Source File: RMEventConsumer.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.resourcemanager.schedule

import java.util.concurrent.{ExecutorService, Future}

import com.webank.wedatasphere.linkis.common.utils.Utils
import com.webank.wedatasphere.linkis.resourcemanager.event.RMEvent
import com.webank.wedatasphere.linkis.resourcemanager.event.metric.{MetricRMEvent, MetricRMEventExecutor}
import com.webank.wedatasphere.linkis.resourcemanager.event.notify.{NotifyRMEvent, NotifyRMEventExecutor}
import com.webank.wedatasphere.linkis.scheduler.SchedulerContext
import com.webank.wedatasphere.linkis.scheduler.queue._

import scala.collection.mutable.ArrayBuffer


/**
  * Consumer thread for resource-manager events.
  *
  * Repeatedly takes events from the configured `ConsumeQueue`, asks the scheduler context's
  * executor manager for an executor and runs the event on it, notifying the optional
  * `RMConsumerListener` before and after each execution.
  *
  * @param schedulerContext scheduler context used to obtain the executor manager
  * @param executeService   executor service this consumer is submitted to by `start()`
  */
class RMEventConsumer(schedulerContext: SchedulerContext,
                      executeService: ExecutorService) extends Consumer(schedulerContext, executeService) {
  private var queue: ConsumeQueue = _
  private var group: Group = _
  private var maxRunningJobsNum = 1000
  // Nothing in this class stores events here yet, so every slot stays null.
  // NOTE(review): the array is sized from the initial maxRunningJobsNum (1000); the auxiliary
  // constructor changes maxRunningJobsNum only after this allocation has happened.
  private val runningJobs = new Array[SchedulerEvent](maxRunningJobsNum)
  private val executorManager = schedulerContext.getOrCreateExecutorManager
  private var rmConsumerListener : RMConsumerListener = _
  // Handle of the submitted consumer task; null until start() has been called.
  var future: Future[_] = _

  /** Creates a consumer bound to `group`, taking the running-job limit from the group's capacity. */
  def this(schedulerContext: SchedulerContext, executeService: ExecutorService, group: Group) = {
    this(schedulerContext, executeService)
    this.group = group
    maxRunningJobsNum = group.getMaximumCapacity
  }

  /** Submits this consumer to the executor service and remembers the resulting future. */
  def start():Unit = future = executeService.submit(this)

  def setRmConsumerListener(rmConsumerListener: RMConsumerListener): Unit ={
    this.rmConsumerListener = rmConsumerListener
  }

  override def setConsumeQueue(consumeQueue: ConsumeQueue) = {
    queue = consumeQueue
  }

  override def getConsumeQueue = queue

  override def getGroup = group

  override def setGroup(group: Group) = {
    this.group = group
  }

  override def getRunningEvents = getEvents(_.isRunning)

  /** Collects the non-null entries of `runningJobs` that satisfy `op`. */
  private def getEvents(op: SchedulerEvent => Boolean): Array[SchedulerEvent] = {
    val result = ArrayBuffer[SchedulerEvent]()
    for (job <- runningJobs if job != null && op(job)) result += job
    result.toArray
  }

  /** Thread body: runs `loop()` until the consumer is terminated, pausing 10 ms between rounds. */
  override def run() = {
    Thread.currentThread().setName(s"${toString}Thread")
    info(s"$toString thread started!")
    while (!terminate) {
      Utils.tryAndError(loop())
      Utils.tryQuietly(Thread.sleep(10))
    }
    info(s"$toString thread stopped!")
  }

  /**
    * Processes one event: takes events until one can be turned to the scheduled state, then
    * executes it on the executor supplied by the executor manager (if any).
    */
  def loop(): Unit = {
    var event = queue.take()
    // Discard events that cannot be moved to the scheduled state.
    while (event.turnToScheduled() != true) {
      event = queue.take()
    }
    if(rmConsumerListener != null){rmConsumerListener.beforeEventExecute(this,event.asInstanceOf[RMEvent])}
    Utils.tryAndError({
      val executor = executorManager.askExecutor(event)
      if (executor.isDefined) {
        // Only metric and notify events are handled; any other event type raises a MatchError
        // inside the surrounding tryAndError.
        event match {
          case x: MetricRMEvent =>{
            Utils.tryQuietly(executor.get.asInstanceOf[MetricRMEventExecutor].execute(new EventJob(x)))
          }
          case y: NotifyRMEvent =>{
            Utils.tryQuietly(executor.get.asInstanceOf[NotifyRMEventExecutor].execute(new EventJob(y)))
          }
        }
      }
    })
    if(rmConsumerListener != null){rmConsumerListener.afterEventExecute(this,event.asInstanceOf[RMEvent])}
  }

  /** Cancels the running consumer task (if it was ever started) and shuts the consumer down. */
  override def shutdown() = {
    // start() may never have been called, in which case there is no task to cancel.
    if (future != null) future.cancel(true)
    super.shutdown()
  }
}
Example 89
Source File: StorageScriptFsReader.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.storage.script.reader

import java.io._

import com.webank.wedatasphere.linkis.common.io.{FsPath, MetaData, Record}
import com.webank.wedatasphere.linkis.storage.script._
import com.webank.wedatasphere.linkis.storage.utils.StorageUtils
import org.apache.commons.io.IOUtils

import scala.collection.mutable.ArrayBuffer



  /**
    * Tells whether `line` looks like a metadata declaration.
    *
    * A line qualifies when it fully matches `<prefix> <name> = <value>` (whitespace optional), or,
    * failing that, when it contains exactly one `=` separated pair whose left side consists of
    * four space-separated tokens, the first of which equals `prefixConf`.
    *
    * @param line       the script line to inspect
    * @param prefix     comment prefix, inserted verbatim into the regular expression
    * @param prefixConf expected first token of the fallback form
    * @return true when the line is recognised as metadata
    */
  def isMetadata(line: String, prefix: String, prefixConf: String): Boolean = {
    val declaration = ("\\s*" + prefix + "\\s*(.+)\\s*" + "=" + "\\s*(.+)\\s*").r
    line match {
      case declaration(_, _) => true
      case _ =>
        val parts = line.split("=")
        parts.length == 2 && {
          // Tokenise the left-hand side once, ignoring the empty strings produced by repeated spaces.
          val tokens = parts(0).split(" ").filter(_.nonEmpty)
          tokens.length == 4 && tokens(0) == prefixConf
        }
    }
  }
} 
Example 90
Source File: ResultSetWriter.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.storage.resultset

import com.webank.wedatasphere.linkis.common.io.resultset.{ResultSet, ResultSetWriter}
import com.webank.wedatasphere.linkis.common.io.{FsPath, MetaData, Record}

import scala.collection.mutable.ArrayBuffer


/**
  * Factory and helper methods around result-set writers and the records stored by them.
  */
object ResultSetWriter {
  /** Creates a storage-backed writer for `resultSet` that writes to `storePath`. */
  def getResultSetWriter[K <: MetaData, V <: Record](resultSet: ResultSet[K,V], maxCacheSize: Long, storePath: FsPath):ResultSetWriter[K, V] =
    new StorageResultSetWriter[K, V](resultSet, maxCacheSize, storePath)

  /** Same as the three-argument variant, additionally setting the proxy user on the writer. */
  def getResultSetWriter[K <: MetaData, V <: Record](resultSet: ResultSet[K,V], maxCacheSize: Long, storePath: FsPath, proxyUser:String):ResultSetWriter[K, V] ={
    val writer = new StorageResultSetWriter[K, V](resultSet, maxCacheSize, storePath)
    writer.setProxyUser(proxyUser)
    writer
  }


  /**
    * Reads back at most `limit` records of what `writer` produced, using the writer's string
    * form as the result location.
    */
  def getRecordByWriter(writer: ResultSetWriter[_ <:MetaData,_ <:Record],limit:Long): Array[Record] ={
    val res = writer.toString
    getRecordByRes(res,limit)
  }

  /**
    * Reads at most `limit` records from the result identified by `res`.
    *
    * @param res   result location understood by `ResultSetReader.getResultSetReader`
    * @param limit maximum number of records to return
    * @return the records read, in stream order
    */
  def getRecordByRes(res: String,limit:Long): Array[Record] ={
    val reader = ResultSetReader.getResultSetReader(res)
    // The reader owns an input stream; release it even when reading fails.
    try {
      var count = 0
      val records = new ArrayBuffer[Record]()
      reader.getMetaData
      while (reader.hasNext && count < limit){
        records += reader.getRecord
        count = count + 1
      }
      records.toArray
    } finally reader.close()
  }

  /**
    * Returns the last record of the result identified by `res` by reading the result to its end.
    */
  def getLastRecordByRes(res: String):Record = {
    val reader = ResultSetReader.getResultSetReader(res)
    try {
      reader.getMetaData
      // Advance to the end; the reader keeps the most recently read record.
      while (reader.hasNext ){
        reader.getRecord
      }
      reader.getRecord
    } finally reader.close()
  }
}
Example 91
Source File: StorageResultSetReader.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.storage.resultset

import java.io.{ByteArrayInputStream, IOException, InputStream}

import com.webank.wedatasphere.linkis.common.io.resultset.{ResultSet, ResultSetReader}
import com.webank.wedatasphere.linkis.common.io.{MetaData, Record}
import com.webank.wedatasphere.linkis.common.utils.Logging
import com.webank.wedatasphere.linkis.storage.domain.Dolphin
import com.webank.wedatasphere.linkis.storage.exception.StorageWarnException
import com.webank.wedatasphere.linkis.storage.utils.StorageUtils

import scala.collection.mutable.ArrayBuffer



  /**
    * Reads the next length-prefixed row from the input stream.
    *
    * @return the row bytes (without the length prefix), or null when reading the length raises
    *         a StorageWarnException, which is treated as the end of the data
    */
  def readLine(): Array[Byte] = {
    // Every row starts with its length; failing to read it signals that the data is exhausted.
    var remaining: Int = try Dolphin.readInt(inputStream) catch {
      case _: StorageWarnException =>
        info(s"Read finished(读取完毕)")
        return null
    }

    val rowBuffer = ArrayBuffer[Byte]()
    var read = 0

    // Read the entire row, excluding the length prefix, in chunks of at most READ_CACHE bytes.
    while (remaining > 0 && read >= 0) {
      val chunk = if (remaining > READ_CACHE) READ_CACHE else remaining
      read = StorageUtils.readBytes(inputStream, bytes, chunk)
      if (read > 0) {
        remaining -= read
        rowBuffer ++= bytes.slice(0, read)
      }
    }
    rowCount = rowCount + 1
    rowBuffer.toArray
  }

  /**
    * Returns the record most recently loaded by `hasNext`.
    *
    * @throws IOException when the metadata has not been read yet or no record is available
    */
  @scala.throws[IOException]
  override def getRecord: Record = {
    if (metaData == null) {
      throw new IOException("Must read metadata first(必须先读取metadata)")
    } else if (row == null) {
      throw new IOException("Can't get the value of the field, maybe the IO stream has been read or has been closed!(拿不到字段的值,也许IO流已读取完毕或已被关闭!)")
    } else {
      row
    }
  }

  /**
    * Returns the result-set metadata, reading it from the stream on first use.
    *
    * The metadata occupies the first row of the stream. It is read only while no metadata is
    * cached, so repeated calls return the cached value instead of consuming a data row and
    * misinterpreting it as metadata.
    *
    * @return the metadata of this result set
    */
  @scala.throws[IOException]
  override def getMetaData: MetaData = {
    if (metaData == null) {
      init()
      metaData = deserializer.createMetaData(readLine())
    }
    metaData
  }

  /**
    * Skips `recordNum` rows by reading each row's length prefix and skipping that many bytes.
    *
    * @param recordNum number of rows to skip
    * @return `recordNum` on success, or -1 when `recordNum` is negative or any error occurs
    */
  @scala.throws[IOException]
  override def skip(recordNum: Int): Int = {
    if (recordNum < 0) return -1

    // The metadata row has to be consumed before any data row can be skipped.
    if (metaData == null) getMetaData
    var pending = recordNum
    while (pending > 0) {
      // NOTE(review): the number of bytes actually skipped is not checked — confirm the stream
      // always skips the full amount.
      try inputStream.skip(Dolphin.readInt(inputStream)) catch { case t: Throwable => return -1 }
      pending -= 1
    }
    recordNum
  }

  /** Returns the current value of `rowCount`, the counter that `readLine` increments per row read. */
  @scala.throws[IOException]
  override def getPosition: Long = rowCount

  /**
    * Reads the next row and stores its deserialized record for a following `getRecord` call.
    *
    * Note that this advances the stream: every call consumes one row.
    *
    * @return true when a record was read, false when no row or no record could be produced
    */
  @scala.throws[IOException]
  override def hasNext: Boolean = {
    if (metaData == null) getMetaData
    val line = readLine()
    if (line == null) {
      false
    } else {
      row = deserializer.createRecord(line)
      row != null
    }
  }

  /** Delegates to the underlying input stream's `available()`. */
  @scala.throws[IOException]
  override def available: Long = inputStream.available()

  /** Closes the underlying input stream. */
  override def close(): Unit = inputStream.close()
} 
Example 92
Source File: TableResultDeserializer.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.storage.resultset.table

import com.webank.wedatasphere.linkis.common.io.resultset.ResultDeserializer
import com.webank.wedatasphere.linkis.storage.domain.{Column, DataType, Dolphin}
import com.webank.wedatasphere.linkis.storage.exception.StorageErrorException

import scala.collection.mutable.ArrayBuffer


  /**
    * Decodes one serialized table row into a [[TableRecord]].
    *
    * Layout as read here: a fixed-width header length, then a header listing the length of each
    * column value separated by `Dolphin.COL_SPLIT`, then the column values back to back.
    *
    * @param bytes the serialized row
    * @return the record holding one decoded value per column
    */
  override def createRecord(bytes: Array[Byte]): TableRecord = {
    val headerLen = Dolphin.getString(bytes, 0, Dolphin.INT_LEN).toInt
    val header = Dolphin.getString(bytes, Dolphin.INT_LEN, headerLen)
    // Drop one trailing character when the header ends with the separator.
    val lengthList =
      if (header.endsWith(Dolphin.COL_SPLIT)) header.substring(0, header.length - 1)
      else header
    val lengths = lengthList.split(Dolphin.COL_SPLIT)
    // Running offset of the next column value inside `bytes`.
    var offset = Dolphin.INT_LEN + headerLen
    val values = lengths.indices.map { col =>
      val valueLen = lengths(col).toInt
      val raw = Dolphin.getString(bytes, offset, valueLen)
      offset += valueLen
      // Columns beyond the known metadata are kept as raw strings.
      if (col < metaData.columns.length) toValue(metaData.columns(col).dataType, raw)
      else raw
    }.toArray
    new TableRecord(values)
  }
} 
Example 93
Source File: RetryHandler.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.common.utils

import com.webank.wedatasphere.linkis.common.exception.{DWCRetryException, FatalException}
import org.apache.commons.lang.{ClassUtils => CommonClassUtils}

import scala.collection.mutable.ArrayBuffer


/**
  * Mix-in that retries an operation when it fails with one of a configurable set of exceptions.
  *
  * The waiting time between attempts grows by a factor of 1.5 per attempt, starting at the
  * retry period and capped at the maximum retry period.
  */
trait RetryHandler extends Logging {

  private var retryNum = 2
  private var period = 100L
  private var maxPeriod = 1000L
  private val retryExceptions = ArrayBuffer[Class[_  <: Throwable]]()

  def setRetryNum(retryNum: Int): Unit = this.retryNum = retryNum
  def getRetryNum: Int = retryNum
  def setRetryPeriod(retryPeriod: Long): Unit = this.period = retryPeriod
  def getRetryPeriod: Long = period
  def setRetryMaxPeriod(retryMaxPeriod: Long): Unit = this.maxPeriod = retryMaxPeriod
  def getRetryMaxPeriod: Long = maxPeriod
  def addRetryException(t: Class[_  <: Throwable]): Unit = retryExceptions += t
  def getRetryExceptions = retryExceptions.toArray

  /** True when `t` is not a FatalException and is assignable to a registered retry exception. */
  def exceptionCanRetry(t: Throwable): Boolean = !t.isInstanceOf[FatalException] &&
    retryExceptions.exists(c => CommonClassUtils.isAssignable(t.getClass, c))

  /**
    * Computes the waiting time before the next attempt.
    *
    * @param attempt 1-based number of the attempt that just failed
    * @return `period * 1.5^(attempt - 1)` in milliseconds, capped at the maximum period
    */
  def nextInterval(attempt: Int): Long = {
    val interval = (this.period.toDouble * Math.pow(1.5D, (attempt - 1).toDouble)).toLong
    if (interval > this.maxPeriod) this.maxPeriod
    else interval
  }

  /**
    * Runs `op`, retrying it after retryable failures.
    *
    * When no retry exception is registered or the retry number is at most 1, `op` is run once
    * without any retry handling. Otherwise a failure is rethrown once the retry number is reached
    * or when the exception is not retryable.
    *
    * @param op        the operation to run
    * @param retryName name used in the retry log message
    * @return the result of the first successful run of `op`
    */
  def retry[T](op: => T, retryName: String): T = {
    if(retryExceptions.isEmpty || retryNum <= 1) return op
    var attempts = 0
    // Success is tracked explicitly: using "result == null" as the signal would re-run an
    // operation that legitimately returns null, without end.
    var succeeded = false
    var result = null.asInstanceOf[T]
    while(!succeeded) result = Utils.tryCatch {
      val value = op
      succeeded = true
      value
    } { t =>
      attempts += 1
      if(attempts >= retryNum) throw t
      else if(exceptionCanRetry(t)) {
        val retryInterval = nextInterval(attempts)
        info(retryName + s" failed with ${t.getClass.getName}, wait ${ByteTimeUtils.msDurationToString(retryInterval)} for next retry. Retried $attempts++ ...")
        Utils.tryQuietly(Thread.sleep(retryInterval))
        null.asInstanceOf[T]
      } else throw t
    }
    result
  }
}
Example 94
Source File: ShutdownUtils.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.common.utils

import sun.misc.{Signal, SignalHandler}

import scala.collection.mutable.ArrayBuffer


/**
  * Registry of shutdown hooks that are run when the process receives a TERM, HUP or INT signal.
  *
  * On a signal, all registered runners are sorted by `order`, submitted to
  * `Utils.defaultScheduler`, and the handler then waits (at most 30 seconds in total) for them
  * to report completion before calling `System.exit(0)`.
  */
object ShutdownUtils {

  // Registered runners; additions are guarded by synchronizing on the buffer itself.
  private val shutdownRunners = ArrayBuffer[ShutdownRunner]()

  /** Registers `runnable` with the lowest priority (`Int.MaxValue`). */
  def addShutdownHook(runnable: Runnable): Unit = addShutdownHook(Int.MaxValue, runnable)

  /** Registers `runnable` with the given `order`; runners are sorted by ascending order. */
  def addShutdownHook(order: Int, runnable: Runnable): Unit =
    shutdownRunners synchronized shutdownRunners += new DefaultShutdownRunner(order, runnable)

  /** Registers a code block with the lowest priority (`Int.MaxValue`). */
  def addShutdownHook(hook: => Unit): Unit = addShutdownHook(Int.MaxValue, hook)

  /** Registers a code block with the given `order`. */
  def addShutdownHook(order: Int, hook: => Unit): Unit =
    shutdownRunners synchronized shutdownRunners += new FunctionShutdownRunner(order, hook)

  /** Registers an already constructed runner. */
  def addShutdownHook(shutdownRunner: ShutdownRunner): Unit =
    shutdownRunners synchronized shutdownRunners += shutdownRunner
  private val signals = Array("TERM", "HUP", "INT").map(new Signal(_))
  private val signalHandler = new SignalHandler {
    override def handle(signal: Signal): Unit = {
      // Runners that are not DefaultShutdownRunner are wrapped in one so that their completion
      // can be observed through isCompleted.
      // NOTE(review): the buffer is read here without the lock used by addShutdownHook — confirm
      // no hook can be registered concurrently with signal handling.
      val hooks = shutdownRunners.sortBy(_.order).toArray.map{
        case m: DefaultShutdownRunner =>
          Utils.defaultScheduler.execute(m)
          m
        case m =>
          val runnable = new DefaultShutdownRunner(m.order, m)
          Utils.defaultScheduler.execute(runnable)
          runnable
      }
      val startTime = System.currentTimeMillis
      // Wait on the ShutdownUtils monitor (notified by each finishing runner) in slices of
      // 3 seconds until every hook completed or 30 seconds have elapsed.
      ShutdownUtils synchronized {
        while(System.currentTimeMillis - startTime < 30000 && hooks.exists(!_.isCompleted))
          ShutdownUtils.wait(3000)
      }
      System.exit(0)
    }
  }
  // Install the handler for every signal as part of object initialization.
  signals.foreach(Signal.handle(_, signalHandler))
}
/** A shutdown hook; `order` is the key the signal handler sorts runners by (ascending). */
trait ShutdownRunner extends Runnable {
  val order: Int
}
/**
  * Shutdown runner that wraps a `Runnable` and records when it has finished.
  *
  * @param order    sort key of this runner
  * @param runnable the work to execute on shutdown
  */
class DefaultShutdownRunner(override val order: Int,
                            runnable: Runnable) extends ShutdownRunner {
  private var completed = false
  // Presumably Utils.tryFinally runs the second block as a finally clause (TODO confirm): the
  // runner is marked completed and a thread waiting on the ShutdownUtils monitor is woken up.
  override def run(): Unit = Utils.tryFinally(runnable.run()){
    completed = true
    ShutdownUtils synchronized ShutdownUtils.notify()
  }
  /** True once `run()` has passed its completion block. */
  def isCompleted = completed
}
/**
  * Shutdown runner backed by a by-name code block, which is evaluated each time `run()` is called.
  *
  * @param order sort key of this runner
  * @param hook  the code to execute on shutdown
  */
class FunctionShutdownRunner(override val order: Int,
                             hook: => Unit) extends ShutdownRunner {
  override def run(): Unit = hook
}
Example 95
Source File: DWCArgumentsParser.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.common.conf

import org.apache.commons.lang.StringUtils

import scala.collection.{JavaConversions, mutable}
import scala.collection.mutable.ArrayBuffer


/**
  * Parsing and formatting of command lines made of `--dwc-conf key=value` and
  * `--spring-conf key=value` pairs.
  */
object DWCArgumentsParser {
  protected val DWC_CONF = "--dwc-conf"
  protected val SPRING_CONF = "--spring-conf"
  private var dwcOptionMap = Map.empty[String, String]

  private[linkis] def setDWCOptionMap(dwcOptionMap: Map[String, String]) = this.dwcOptionMap = dwcOptionMap
  def getDWCOptionMap = dwcOptionMap

  /**
    * Parses a command line into a [[DWCArgumentsParser]].
    *
    * @param args arguments in the form `--dwc-conf k=v --spring-conf k=v ...`
    * @return the parser holding the collected options
    * @throws IllegalArgumentException when a flag is unknown, is not followed by an argument,
    *                                  or its argument is not of the form `key=value`
    */
  def parse(args: Array[String]): DWCArgumentsParser = {
    val keyValueRegex = "([^=]+)=(.+)".r
    var i = 0
    val optionParser = new DWCArgumentsParser
    while(i < args.length) {
      args(i) match {
        case DWC_CONF | SPRING_CONF =>
          // A trailing flag without its key=value argument is a malformed command line, not an
          // index error.
          if (i + 1 >= args.length)
            throw new IllegalArgumentException("illegal commond line, format: --conf key=value.")
          args(i + 1) match {
            case keyValueRegex(key, value) =>
              optionParser.setConf(args(i), key, value)
              // Consume the key=value argument; the flag itself is consumed below.
              i += 1
            case _ => throw new IllegalArgumentException("illegal commond line, format: --conf key=value.")
          }
        case _ => throw new IllegalArgumentException(s"illegal commond line, ${args(i)} cannot recognize.")
      }
      i += 1
    }
    optionParser.validate()
    optionParser
  }

  /**
    * Renders the options of `optionParser` as an argument array, DWC options first, each option
    * as a flag followed by `key=value`. Entries with an empty key or value are left out.
    */
  def formatToArray(optionParser: DWCArgumentsParser): Array[String] = {
    val options = ArrayBuffer[String]()
    def write(confMap: Map[String, String], optionType: String): Unit = confMap.foreach { case (key, value) =>
      if (StringUtils.isNotEmpty(key) && StringUtils.isNotEmpty(value)) {
        options += optionType
        options += (key + "=" + value)
      }
    }
    write(optionParser.getDWCConfMap, DWC_CONF)
    write(optionParser.getSpringConfMap, SPRING_CONF)
    options.toArray
  }
  def formatToArray(springOptionMap: Map[String, String], dwcOptionMap: Map[String, String]): Array[String] =
    formatToArray(new DWCArgumentsParser().setSpringConf(springOptionMap).setDWCConf(dwcOptionMap))

  /** Same as `formatToArray`, joined with single spaces. */
  def format(optionParser: DWCArgumentsParser): String = formatToArray(optionParser).mkString(" ")
  def format(springOptionMap: Map[String, String], dwcOptionMap: Map[String, String]): String =
    formatToArray(springOptionMap, dwcOptionMap).mkString(" ")

  /** Renders each non-empty entry as a single `--key=value` argument. */
  def formatSpringOptions(springOptionMap: Map[String, String]): Array[String] = {
    val options = ArrayBuffer[String]()
    springOptionMap.foreach { case (key, value) =>
      if (StringUtils.isNotEmpty(key) && StringUtils.isNotEmpty(value)) {
        options += ("--" + key + "=" + value)
      }
    }
    options.toArray
  }
}
/**
  * Mutable holder of the DWC and Spring option maps collected from a command line.
  * All setters return this instance so that calls can be chained.
  */
class DWCArgumentsParser {
  import DWCArgumentsParser._
  private val dwcOptionMap = new mutable.HashMap[String, String]()
  private val springOptionMap = new mutable.HashMap[String, String]()

  /** Immutable snapshot of the Spring options. */
  def getSpringConfMap = springOptionMap.toMap
  /** Java view of the Spring options. */
  def getSpringConfs = JavaConversions.mapAsJavaMap(springOptionMap)
  /** Immutable snapshot of the DWC options. */
  def getDWCConfMap = dwcOptionMap.toMap

  /** Stores `key -> value` in the map selected by `optionType` (one of the two known flags). */
  def setConf(optionType: String, key: String, value: String) = {
    val target = optionType match {
      case DWC_CONF => dwcOptionMap
      case SPRING_CONF => springOptionMap
    }
    target.put(key, value)
    this
  }

  /** Adds all entries of `optionMap` to the Spring options; a null map is ignored. */
  def setSpringConf(optionMap: Map[String, String]): DWCArgumentsParser = {
    Option(optionMap).foreach(entries => springOptionMap ++= entries)
    this
  }

  /** Adds all entries of `optionMap` to the DWC options; a null map is ignored. */
  def setDWCConf(optionMap: Map[String, String]): DWCArgumentsParser = {
    Option(optionMap).foreach(entries => dwcOptionMap ++= entries)
    this
  }

  /** Validation hook; currently performs no checks. */
  def validate() = {}
}
Example 96
Source File: _03_TraitsAsStackableModifications.scala    From LearningScala   with Apache License 2.0 5 votes vote down vote up
package _033_traits

import scala.collection.mutable.ArrayBuffer


  // Named class for the mixin composition "BasicIntQueue with Doubling" used in main below.
  class MyQueue extends BasicIntQueue with Doubling

  /**
    * Demonstrates stacking traits on a queue: a plain queue, a named mixin class, an anonymous
    * mixin, and two compositions that differ only in the order of their traits.
    */
  def main(args: Array[String]): Unit = {
    // Puts `inputs` into `q`, prints `gets` dequeued values labelled with `label`, then a blank line.
    def demo(label: String, q: BasicIntQueue, inputs: Seq[Int], gets: Int): Unit = {
      inputs.foreach(x => q.put(x))
      (1 to gets).foreach(_ => println(s"$label.get(): ${q.get()}"))
      println()
    }

    val queue = new BasicIntQueue
    demo("queue", queue, Seq(-10, 20), gets = 2)

    val myQueue = new MyQueue
    demo("myQueue", myQueue, Seq(-10, 20), gets = 2)

    // You could supply "BasicIntQueue with Doubling" directly to new instead of defining a named class.
    val queueWithDoubling = new BasicIntQueue with Doubling
    demo("queueWithDoubling", queueWithDoubling, Seq(-10, 20), gets = 2)

    // ORDER MATTERS examples:
    // You can now pick and choose which traits you want for a particular queue.
    val q1 = new BasicIntQueue with Incrementing with Filtering
    // Only two values are read back here; a third q1.get() will give an error.
    demo("q1", q1, Seq(-1, 0, 1), gets = 2)

    val q2 = new BasicIntQueue with Filtering with Incrementing
    demo("q2", q2, Seq(-1, 0, 1), gets = 3)
  }
} 
Example 97
Source File: _10_MutableCollections.scala    From LearningScala   with Apache License 2.0 5 votes vote down vote up
package _020_collections


/**
  * Tour of the mutable collections: list buffers, array buffers, mutable sets and mutable maps.
  * Each example prints the collection after every modification.
  */
object _10_MutableCollections {
  def main(args: Array[String]): Unit = {
    val sections: Seq[(String, () => Unit)] = Seq(
      ("===== List buffers =====", () => listBufferExample()),
      ("===== Array buffers =====", () => println(arrayBufferExample())),
      ("===== Mutable Sets =====", () => mutableSetExample()),
      ("===== Mutable Maps =====", () => mutableMapExample())
    )
    // Sections are separated by one blank line; none is printed after the last section.
    sections.zipWithIndex.foreach { case ((title, runSection), position) =>
      if (position > 0) println()
      println(title)
      runSection()
    }
  }

  /** Shows insertion, removal, bulk updates and the common queries of a mutable map. */
  private def mutableMapExample(): Unit = {
    import scala.collection.mutable
    val map = mutable.Map.empty[String, Int]
    println(map)
    map += ("hello" -> 1)
    map += ("there" -> 2)
    println(map)
    println(map("hello"))
    println("======")

    val nums = mutable.Map("i" -> 1, "ii" -> 2)
    println(nums)
    nums += ("vi" -> 6)
    println(nums)
    nums -= "ii"
    println(nums)
    nums ++= List("iii" -> 3, "v" -> 5)
    println(nums)
    nums --= List("i", "ii")
    println(nums)
    println("=====")

    println(s"nums.size: ${nums.size}")
    println("nums.contains(\"ii\"): " + nums.contains("ii"))
    println("nums(\"iii\"): " + nums("iii"))
    println(s"nums.keys ==> ${nums.keys}")
    println(s"nums.keySet ==> ${nums.keySet}")
    println(s"nums.values ==> ${nums.values}")
    println(s"nums.isEmpty: ${nums.isEmpty}")
  }

  /**
    * Builds an array buffer by appending and prepending.
    *
    * @return the buffer contents as an immutable list
    */
  def arrayBufferExample(): List[Int] = {
    import scala.collection.mutable.ArrayBuffer
    val ab = ArrayBuffer[Int](10, 20)
    ab ++= Seq(30, 40)
    5 +=: ab
    ab.toList // hand out an immutable copy
  }

  /** Shows appending and prepending to a list buffer and converting it to a list. */
  private def listBufferExample(): Unit = {
    import scala.collection.mutable.ListBuffer
    val listBuffer = new ListBuffer[Int]
    listBuffer ++= List(1, 2)
    println(listBuffer)
    listBuffer.prepend(3)
    println(listBuffer)
    println(listBuffer.toList)
  }

  /** Shows element and bulk updates, intersection and clearing of a mutable set. */
  private def mutableSetExample(): Unit = {
    import scala.collection.mutable
    println(mutable.Set.empty[Int])

    val nums = mutable.Set(1, 2, 3)
    println(nums)
    nums += 5
    println(nums)
    nums -= 3
    println(nums)
    nums ++= List(5, 6)
    println(nums)
    nums --= List(1, 2)
    println(nums)
    println(nums.intersect(Set(1, 3, 5, 7))) // intersection of two sets
    nums.clear()
    println(nums)
  }
}
Example 98
Source File: TestableQueueInputDStream.scala    From SparkUnitTestingExamples   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming

import java.io.{ObjectInputStream, ObjectOutputStream}

import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.streaming.dstream.InputDStream

import scala.collection.mutable.{ArrayBuffer, Queue}
import scala.reflect.ClassTag

/**
  * Input stream for tests that is fed from an in-memory queue of RDDs.
  *
  * @param ssc        the streaming context
  * @param queue      queue the batches are taken from; guarded by its own monitor
  * @param oneAtATime when true, at most one queued RDD is emitted per batch
  * @param defaultRDD RDD emitted when the queue yields nothing; may be null
  */
class TestableQueueInputDStream[T: ClassTag](
                                              ssc: StreamingContext,
                                              val queue: Queue[RDD[T]],
                                              oneAtATime: Boolean,
                                              defaultRDD: RDD[T]
                                              ) extends InputDStream[T](ssc) {

  override def start(): Unit = {}

  override def stop(): Unit = {}

  private def readObject(in: ObjectInputStream): Unit = {
    logWarning("queueStream doesn't support checkpointing")
  }

  private def writeObject(oos: ObjectOutputStream): Unit = {
    logWarning("queueStream doesn't support checkpointing")
  }

  /**
    * Produces the RDD for one batch: the head of the queue (one-at-a-time mode), the union of
    * everything queued, the default RDD when nothing was queued, or an empty RDD as last resort.
    */
  override def compute(validTime: Time): Option[RDD[T]] = {
    val drained = new ArrayBuffer[RDD[T]]()
    queue.synchronized {
      if (oneAtATime && queue.nonEmpty) {
        drained += queue.dequeue()
      } else {
        drained ++= queue
        queue.clear()
      }
    }
    if (drained.isEmpty) {
      if (defaultRDD != null) Some(defaultRDD)
      else Some(ssc.sparkContext.emptyRDD)
    } else if (oneAtATime) {
      Some(drained.head)
    } else {
      Some(new UnionRDD(context.sc, drained.toSeq))
    }
  }

}
Example 99
Source File: PruneWorker.scala    From spatial   with MIT License 5 votes vote down vote up
package spatial.dse

import java.util.concurrent.LinkedBlockingQueue

import argon.State
import spatial.metadata.params._
import spatial.metadata.bounds._

import scala.collection.mutable.ArrayBuffer

/**
  * Worker that scans the index range `[start, start + size)` of a design space and keeps the
  * indices whose parameter assignment satisfies every restriction.
  *
  * @param start        first index to examine
  * @param size         number of indices to examine
  * @param prods        per-dimension divisors used to decode an index
  * @param dims         per-dimension sizes used to decode an index
  * @param indexedSpace the domains paired with the dimension they are decoded from
  * @param restricts    restrictions a legal point has to satisfy
  * @param queue        queue the legal indices are published to
  */
case class PruneWorker(
  start: Int,
  size: Int,
  prods: Seq[BigInt],
  dims:  Seq[BigInt],
  indexedSpace: Seq[(Domain[_],Int)],
  restricts: Set[Restrict],
  queue: LinkedBlockingQueue[Seq[Int]]
)(implicit state: State) extends Runnable {

  /** True when the currently set domain values satisfy all restrictions. */
  private def isLegalSpace(): Boolean = restricts.forall(_.evaluate())

  def run(): Unit = {
    println(s"Searching from $start until ${start+size}")
    val end = start + size
    val legalPoints = (start until end).filter { point =>
      // Decode the flat index into one value per domain before evaluating the restrictions.
      indexedSpace.foreach { case (domain, d) =>
        domain.set(((point / prods(d)) % dims(d)).toInt)
      }
      isLegalSpace()
    }
    queue.put(legalPoints)
  }
}
Example 100
Source File: Flows.scala    From spatial   with MIT License 5 votes vote down vote up
package argon

import scala.collection.mutable.{ArrayBuffer,HashSet}

import utils.Instrument

/** Base trait for flow-rule definitions; implementors supply the compiler state as `IR`. */
trait FlowRules {
  val IR: State

}


/**
  * Ordered, named collection of flow rules. Every rule is a partial function over
  * `(symbol, op, source context, state)`; `apply` runs each rule that is defined for a node.
  */
class Flows {
  private var rules = ArrayBuffer[(String,PartialFunction[(Sym[_],Op[_],SrcCtx,State),Unit])]()
  private[argon] var names = HashSet[String]()

  lazy val instrument = new Instrument("flows")

  /** Registers `func` under `name` so that it runs before all rules registered so far. */
  def prepend(name: String, func: PartialFunction[(Sym[_],Op[_],SrcCtx,State),Unit]): Unit = {
    rules.prepend((name,func))
    names += name
  }

  /** Registers `func` under `name` so that it runs after all rules registered so far. */
  def add(name: String, func: PartialFunction[(Sym[_],Op[_],SrcCtx,State),Unit]): Unit = {
    rules += ((name,func))
    names += name
  }

  /** Removes the first rule registered under `name`; unknown names are ignored. */
  def remove(name: String): Unit = {
    val idx = rules.indexWhere(_._1 == name)
    // indexWhere yields -1 for an unknown name, which ArrayBuffer.remove rejects with an
    // IndexOutOfBoundsException.
    if (idx >= 0) rules.remove(idx)
    names.remove(name)
  }

  /** Runs, in registration order, every rule that is defined for the given node. */
  def apply[A](lhs: Sym[A], rhs: Op[A])(implicit ctx: SrcCtx, state: State): Unit = {
    val tuple = (lhs,rhs,ctx,state)
    rules.foreach{case (name,rule) =>
      if (rule.isDefinedAt(tuple)) { instrument(name){ rule.apply(tuple) } }
    }
  }

  /** Returns a new `Flows` holding copies of the current rules and names. */
  def save(): Flows = {
    val flows = new Flows
    flows.rules ++= rules
    flows.names ++= names
    flows
  }

  /**
    * Replaces the current rules and names with those of `flow`.
    * The collections are adopted as they are, so this instance and `flow` share them afterwards.
    */
  def restore(flow: Flows): Unit = {
    rules = flow.rules
    names = flow.names
  }
}
Example 101
Source File: Rewrites.scala    From spatial   with MIT License 5 votes vote down vote up
package argon

import utils.implicits.collections._

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/** Base trait for rewrite-rule definitions; implementors supply the compiler state as `IR`. */
trait RewriteRules {
  val IR: State
}


/**
  * Registry of rewrite rules, kept either per node class or globally. Rule names are unique:
  * registering a second rule under an existing name is a no-op.
  */
class Rewrites {
  type RewriteRule = PartialFunction[(Op[_],SrcCtx,State),Option[Sym[_]]]

  private def keyOf[A<:Op[_]:Manifest] = manifest[A].runtimeClass.asInstanceOf[Class[A]]

  // Roughly O(G), where G is the total number of global rewrite rules
  // When possible, use rules instead of globals
  private val globals: ArrayBuffer[RewriteRule] = ArrayBuffer.empty

  // Roughly O(R), where R is the number of rules for a specific node class
  private val rules: mutable.HashMap[Class[_], ArrayBuffer[RewriteRule]] = mutable.HashMap.empty
  private[argon] val names: mutable.HashSet[String] = mutable.HashSet.empty

  /** The rules registered for the runtime class of `op`, or an empty sequence. */
  def rule(op: Op[_]): Seq[RewriteRule] = rules.getOrElse(op.getClass, Nil)

  /** Registers a rule that is tried for every node, unless `name` is already taken. */
  def addGlobal(name: String, rule: RewriteRule): Unit = {
    // HashSet.add reports whether the name was newly inserted.
    if (names.add(name)) {
      globals += rule
    }
  }

  /** Registers a rule for node class `O`, unless `name` is already taken. */
  def add[O<:Op[_]:Manifest](name: String, rule: RewriteRule): Unit = {
    if (names.add(name)) {
      val forClass = rules.getOrElseAdd(keyOf[O], () => ArrayBuffer.empty[RewriteRule])
      forClass += rule
    }
  }

  /** Applies one rule and keeps its result only when the result type conforms to `A`. */
  private def applyRule[A:Type](op: Op[A], ctx: SrcCtx, state: State, rule: RewriteRule): Option[A] = {
    rule.apply((op,ctx,state)) match {
      case Some(s) if s.tp <:< Type[A] => Some(s.asInstanceOf[A])
      case _ => None
    }
  }

  /**
    * Tries to rewrite `op`: first through the op's own `rewrite`, then through the rules of its
    * class, then through the global rules. The first successful rewrite wins.
    */
  def apply[A:Type](op: Op[A])(implicit ctx: SrcCtx, state: State): Option[A] = {
    val rewritten = Option(op.rewrite)
      .orElse{ rule(op).mapFind{rule => applyRule[A](op,ctx,state, rule) } }
      .orElse{ globals.mapFind{rule => applyRule[A](op,ctx,state, rule) } }
    rewritten.foreach { op2 =>
      if (state.config.enLog) {
        dbgs(s"Rewrite $op => $op2")
      }
    }
    rewritten
  }
}
Example 102
Source File: BitTest.scala    From spatial   with MIT License 5 votes vote down vote up
package spatial.tests.compiler

import spatial.dsl._

import scala.collection.mutable.ArrayBuffer

// Compiler test that builds layers of randomly chosen bit operations on top of 32 random bits.
// The operator choices are made with scala.util.Random on plain (gen.Int) integers.
@spatial class BitTest extends SpatialTest {
  override def backends = DISABLED

  // Returns a random number in [min,max)
  // Note the parameter order: the upper bound comes first.
  def rand(max: gen.Int, min: gen.Int): gen.Int = scala.util.Random.nextInt(max-min)+min

  // Combines x and y with the operation selected by op. Only 0 to 13 are handled, which is the
  // range produced by rand(14,0) below; AND, OR, XOR (!==) and XNOR (===) each get three
  // selector values, the two negations one each.
  def opp(x: Bit, y: Bit, op: gen.Int): Bit = op match {
    case 0 | 1 | 2 => x & y
    case 3 | 4 | 5 => x | y
    case 6 | 7 | 8 => x !== y
    case 9 | 10 | 11 => x === y
    case 12 => !x
    case 13 => !y
  }

  def main(args: Array[String]): Void = {
    Foreach(0 until 32){i =>
      val bits: List[Bit] = List.fill(32){ random[Bit] }
      var layers: ArrayBuffer[List[Bit]] = ArrayBuffer(bits)

      // The inner i shadows the Foreach index. In round i the buffer holds i+1 layers, so
      // index i is the layer appended most recently.
      (0 until 64).meta.foreach{i =>
        val layer = List.fill(200){
          val l1 = i //rand(layers.length,0)
          val l2 = i //rand(layers.length,0)
          val p1 = rand(layers(l1).length, 0)
          val p2 = rand(layers(l2).length, 0)
          val op = rand(14,0)
          val x = layers(l1).apply(p1)
          val y = layers(l2).apply(p2)
          opp(x,y,op)
        }
        layers += layer

        println(r"[$i] 1: ${layer(1)}, 3: ${layer(3)}, 5: ${layer(5)}")
      }
    }
  }
}
Example 103
Source File: TemplateRunner.scala    From spatial   with MIT License 5 votes vote down vote up
package fringe.test

import java.io.File

import scala.collection.mutable.ArrayBuffer
import scala.util.Properties.envOrElse

/**
  * Command-line style runner for template tests: selects tests by name, runs each against the
  * configured backend and prints a summary. The JVM exits with status 1 when any test failed.
  */
object TemplateRunner {
  /**
    * Deletes `file`, descending into directories first. A non-existent file is ignored.
    *
    * @throws Exception when an existing file or directory cannot be deleted
    */
  def deleteRecursively(file: File): Unit = {
    if (file.isDirectory) {
      // listFiles returns null when the directory cannot be listed (I/O error, no permission).
      Option(file.listFiles).foreach(_.foreach(deleteRecursively))
    }
    if (file.exists && !file.delete)
      throw new Exception(s"Unable to delete ${file.getAbsolutePath}")
  }

  /**
    * Runs the selected templates.
    *
    * @param templateMap test name to test function; the function receives the backend name and
    *                    returns whether the test passed
    * @param args        selection: empty or "all" runs everything, a name ending in digits runs
    *                    that test, any other value runs all tests named `<value><digits>`
    */
  def apply(templateMap: Map[String, String => Boolean], args: Array[String]): Unit = {
    // Choose the default backend based on what is available.
    lazy val firrtlTerpBackendAvailable: Boolean = {
      try {
        val cls = Class.forName("chisel3.iotesters.FirrtlTerpBackend")
        cls != null
      } catch {
        // Class loading can fail with errors as well as exceptions; any failure means "absent".
        case e: Throwable => false
      }
    }
    lazy val defaultBackend = if (firrtlTerpBackendAvailable) "firrtl" else ""

    val backendName = envOrElse("TESTER_BACKENDS", defaultBackend).split(" ").head
    val tempDir = s"""${envOrElse("NEW_TEMPLATES_HOME", "tmp")}/test_run_dir/"""
    val specificRegex = "(.*[0-9]+)".r
    val problemsToRun = if (args.isEmpty) {
      templateMap.keys.toSeq.sorted.toArray // Run all by default
    } else {
      args.map {
        case "all" => templateMap.keys.toSeq.sorted // Run all
        case specificRegex(c) => List(c).toSeq // Run specific test
        case arg => // Figure out tests that match this template and run all
          val tempRegex = s"(${arg}[0-9]+)".r
          templateMap.keys.toSeq.sorted.filter(tempRegex.pattern.matcher(_).matches)
      }.flatten.toArray
    }

    var successful = 0
    // Buffer instead of an immutable list: appending to a List copies it on every append.
    val passedTests = new ArrayBuffer[String]
    val errors = new ArrayBuffer[String]
    for(testName <- problemsToRun) {
      // Wipe tempdir for consecutive tests of same module
      deleteRecursively(new File(tempDir))
      templateMap.get(testName) match {
        case Some(test) =>
          println(s"Starting template $testName")
          try {
            if(test(backendName)) {
              successful += 1
              passedTests += s"$testName"
            }
            else {
              errors += s"Template $testName: test error occurred"
            }
          }
          catch {
            case exception: Exception =>
              exception.printStackTrace()
              errors += s"Template $testName: exception ${exception.getMessage}"
            // Errors thrown by a test are recorded as a failure of that test so that the
            // remaining templates still run.
            case t : Throwable =>
              errors += s"Template $testName: throwable ${t.getMessage}"
          }
        case _ =>
          errors += s"Bad template name: $testName"
      }
    }
    if(successful > 0) {
      println(s"""Templates passing: $successful (${passedTests.mkString(", ")})""")
    }
    if(errors.nonEmpty) {
      println("=" * 80)
      println(s"Errors: ${errors.length}: in the following templates")
      println(errors.mkString("\n"))
      println("=" * 80)
      System.exit(1)
    }
  }
}
Example 104
Source File: AvroSchemaMerge.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.avro

import com.sksamuel.exts.StringOption
import org.apache.avro.Schema

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

object AvroSchemaMerge {

  /**
   * Merges several record schemas into one record schema.
   *
   * Fields are taken in order: every field of the first schema, then the fields of each
   * following schema whose name has not been seen yet. Docs of all schemas that have one
   * are joined with "; ".
   *
   * @param name      name of the resulting record
   * @param namespace namespace of the resulting record
   * @param schemas   schemas to merge; each one must be of type RECORD
   * @return a new record schema containing the merged fields
   */
  def apply(name: String, namespace: String, schemas: List[Schema]): Schema = {
    require(schemas.forall(_.getType == Schema.Type.RECORD), "Can only merge records")

    // the merged documentation is simply the concatenation of the non-null docs
    val mergedDoc = schemas.map(_.getDoc).filter(_ != null).mkString("; ")

    val merged = new ArrayBuffer[Schema.Field]()
    for (source <- schemas) {
      // computed eagerly, i.e. before any field of this schema is appended
      val unseen = source.getFields.asScala.filterNot(candidate => merged.exists(_.name() == candidate.name))
      for (original <- unseen) {
        // avro is funny about sharing fields between schemas, so each one is copied
        merged.append(
          new Schema.Field(original.name(), original.schema(), StringOption(original.doc).orNull, original.defaultVal)
        )
      }
    }

    // an empty merged doc is passed to avro as null
    val docOrNull = if (mergedDoc.isEmpty()) null else mergedDoc
    val record = Schema.createRecord(name, docOrNull, namespace, false)
    record.setFields(merged.result().asJava)
    record
  }
}
Example 105
Source File: JdbcPublisher.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}
import java.util.concurrent.atomic.AtomicBoolean

import com.sksamuel.exts.io.Using
import com.sksamuel.exts.metrics.Timed
import io.eels.Row
import io.eels.component.jdbc.dialect.JdbcDialect
import io.eels.datastream.{Publisher, Subscriber, Subscription}

import scala.collection.mutable.ArrayBuffer

/**
 * Publisher that runs a JDBC query and pushes the result set to a subscriber in batches.
 *
 * @param connFn    factory invoked once per subscription to obtain the connection
 * @param query     SQL text passed to `prepareStatement`
 * @param bindFn    callback applied to the prepared statement before it is executed
 * @param fetchSize used as the statement fetch size, as the initial buffer capacity and as
 *                  the number of rows per emitted batch
 * @param dialect   used to derive the schema and to sanitize each raw column value
 */
class JdbcPublisher(connFn: () => Connection,
                    query: String,
                    bindFn: (PreparedStatement) => Unit,
                    fetchSize: Int,
                    dialect: JdbcDialect
              ) extends Publisher[Seq[Row]] with Timed with JdbcPrimitives with Using {

  /**
   * Executes the query and feeds the rows to `subscriber`.
   *
   * Rows are emitted as `Vector`s of `fetchSize` rows, followed by one smaller batch if rows
   * remain, and then `completed()` is called. Any `Throwable` raised along the way is passed
   * to `subscriber.error` instead of being rethrown.
   */
  override def subscribe(subscriber: Subscriber[Seq[Row]]): Unit = {
    try {
      // NOTE(review): `using` comes from the Using mixin; presumably it closes the resource
      // once the block finishes - confirm against com.sksamuel.exts.io.Using.
      using(connFn()) { conn =>

        logger.debug(s"Preparing query $query")
        using(conn.prepareStatement(query)) { stmt =>

          stmt.setFetchSize(fetchSize)
          bindFn(stmt)

          logger.debug(s"Executing query $query")
          using(stmt.executeQuery()) { rs =>

            val schema = schemaFor(dialect, rs)

            // flag handed to the subscriber through the subscription; the read loop below
            // stops as soon as it is observed to be false
            val running = new AtomicBoolean(true)
            subscriber.subscribed(Subscription.fromRunning(running))

            val buffer = new ArrayBuffer[Row](fetchSize)
            while (rs.next && running.get) {
              // values are read by column name, in schema field order
              val values = schema.fieldNames().map { name =>
                val raw = rs.getObject(name)
                dialect.sanitize(raw)
              }
              buffer append Row(schema, values)
              // a full batch is emitted as an immutable snapshot, then the buffer is reused
              if (buffer.size == fetchSize) {
                subscriber.next(buffer.toVector)
                buffer.clear()
              }
            }

            // emit whatever is left over (also reached when the loop stopped via `running`)
            if (buffer.nonEmpty)
              subscriber.next(buffer.toVector)

            subscriber.completed()
          }
        }
      }
    } catch {
      case t: Throwable => subscriber.error(t)
    }
  }
}
Example 106
Source File: HbasePublisher.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.hbase

import java.util
import java.util.concurrent.atomic.AtomicBoolean

import com.sksamuel.exts.io.Using
import com.sksamuel.exts.metrics.Timed
import io.eels.Row
import io.eels.datastream.{Publisher, Subscriber, Subscription}
import io.eels.schema.StructType
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.{Connection, Result, Scan}

import scala.collection.mutable.ArrayBuffer

/**
 * Publisher that scans an HBase table and pushes the rows to a subscriber in batches.
 *
 * @param connection HBase connection used to look up the table
 * @param schema     schema used to decode each scan result into a Row
 * @param namespace  namespace of the table
 * @param tableName  name of the table
 * @param bufferSize number of rows per emitted batch (also the initial buffer capacity)
 * @param maxRows    upper bound on the number of rows handed out per subscription
 * @param scanner    scan definition passed to `table.getScanner`
 * @param serializer converts raw bytes into field values
 */
class HbasePublisher(connection: Connection,
                     schema: StructType,
                     namespace: String,
                     tableName: String,
                     bufferSize: Int,
                     maxRows: Long,
                     scanner: Scan,
                     implicit val serializer: HbaseSerializer) extends Publisher[Seq[Row]] with Timed with Using {

  // looked up once at construction time; this class never closes it
  private val table = connection.getTable(TableName.valueOf(namespace, tableName))

  /**
   * Scans the table and feeds the rows to `subscriber` as `Vector`s of `bufferSize` rows,
   * followed by one smaller batch if rows remain, and finally `completed()`.
   * Any `Throwable` is passed to `subscriber.error` instead of being rethrown.
   */
  override def subscribe(subscriber: Subscriber[Seq[Row]]): Unit = {
    try {
      // NOTE(review): `using` comes from the Using mixin; presumably it calls close() on the
      // iterator when the block finishes - confirm against com.sksamuel.exts.io.Using.
      using(new CloseableIterator) { rowIter =>
        // flag handed to the subscriber through the subscription; the loop stops once it is false
        val running = new AtomicBoolean(true)
        subscriber.subscribed(Subscription.fromRunning(running))
        val buffer = new ArrayBuffer[Row](bufferSize)
        while (rowIter.hasNext && running.get()) {
          buffer append rowIter.next()
          // a full batch is emitted as an immutable snapshot, then the buffer is reused
          if (buffer.size == bufferSize) {
            subscriber.next(buffer.toVector)
            buffer.clear()
          }
        }
        // emit whatever is left over
        if (buffer.nonEmpty) subscriber.next(buffer.toVector)
        subscriber.completed()
      }
    } catch {
      case t: Throwable => subscriber.error(t)
    }
  }

  /**
   * Row iterator over a freshly opened result scanner that stops after `maxRows` rows and
   * closes the scanner in `close()`.
   */
  class CloseableIterator extends Iterator[Row] with AutoCloseable {
    private val resultScanner = table.getScanner(scanner)
    private val resultScannerIter = resultScanner.iterator()
    // number of rows handed out so far through next()
    private var rowCount = 0
    // current decoding iterator; starts empty and is replaced lazily inside hasNext
    private var iter: Iterator[Row] = Iterator.empty

    // `&&` binds tighter than `||`: first ask the current iterator (while under the row cap);
    // otherwise, if the scanner still has results, wrap it in a new decoding iterator.
    override def hasNext: Boolean = rowCount < maxRows && iter.hasNext || {
      if (rowCount < maxRows && resultScannerIter.hasNext) {
        iter = HBaseResultsIterator(schema, resultScannerIter)
        iter.hasNext
      } else false
    }

    override def next(): Row = {
      rowCount += 1
      iter.next()
    }

    override def close(): Unit = {
      resultScanner.close()
    }
  }

  /**
   * Decodes each HBase `Result` of `resultIter` into a `Row` of `schema`.
   */
  case class HBaseResultsIterator(schema: StructType, resultIter: util.Iterator[Result])(implicit serializer: HbaseSerializer) extends Iterator[Row] {
    override def hasNext: Boolean = resultIter.hasNext

    override def next(): Row = {
      val resultRow = resultIter.next()
      val values = schema.fields.map { field =>
        if (!field.key) {
          // non-key field: read the cell at (column family, field name); a missing column
          // family is reported through sys.error, a missing cell yields null
          val value = resultRow.getValue(field.columnFamily.getOrElse(sys.error(s"No Column Family defined for field '${field.name}'")).getBytes, field.name.getBytes)
          if (value != null) serializer.fromBytes(value, field.name, field.dataType) else null
        } else serializer.fromBytes(resultRow.getRow, field.name, field.dataType) // key field: decoded from the row key
      }
      Row(schema, values)
    }
  }


}
Example 107
Source File: OrcWriter.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.orc

import java.util.concurrent.atomic.AtomicInteger
import java.util.function.IntUnaryOperator

import com.sksamuel.exts.Logging
import com.typesafe.config.ConfigFactory
import io.eels.Row
import io.eels.schema.StructType
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.exec.vector.ColumnVector
import org.apache.orc.{OrcConf, OrcFile, TypeDescription}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

/**
 * Performs the actual write out of orc data, to be used by an orc sink.
 *
 * Rows passed to [[write]] are buffered and written to the underlying orc writer in batches
 * of `batchSize` rows. [[close]] flushes any remaining rows and closes the file.
 *
 * @param path       destination of the orc file
 * @param structType schema of the rows; converted to an orc `TypeDescription`
 * @param options    compression, encoding, stride and bloom filter settings
 * @param conf       hadoop configuration; the orc settings from `options` are written into it
 */
class OrcWriter(path: Path,
                structType: StructType,
                options: OrcWriteOptions)(implicit conf: Configuration) extends Logging {

  private val schema: TypeDescription = OrcSchemaFns.toOrcSchema(structType)
  logger.trace(s"Creating orc writer for schema $schema")

  // configured batch size, clamped to the range [1, 1024]
  private val batchSize = {
    val size = ConfigFactory.load().getInt("eel.orc.sink.batchSize")
    Math.max(Math.min(1024, size), 1)
  }
  logger.debug(s"Orc writer will use batchsize=$batchSize")

  private val buffer = new ArrayBuffer[Row](batchSize)
  // one serializer per top-level column, in schema order
  private val serializers = schema.getChildren.asScala.map(OrcSerializer.forType).toArray
  private val batch = schema.createRowBatch(batchSize)

  OrcConf.COMPRESSION_STRATEGY.setString(conf, options.compressionStrategy.name)
  OrcConf.COMPRESS.setString(conf, options.compressionKind.name)
  options.encodingStrategy.map(_.name).foreach(OrcConf.ENCODING_STRATEGY.setString(conf, _))
  options.compressionBufferSize.foreach(OrcConf.BUFFER_SIZE.setLong(conf, _))
  private val woptions = OrcFile.writerOptions(conf).setSchema(schema)

  options.rowIndexStride.foreach { size =>
    woptions.rowIndexStride(size)
    logger.debug(s"Using stride size = $size")
  }

  if (options.bloomFilterColumns.nonEmpty) {
    val columns = options.bloomFilterColumns.mkString(",")
    woptions.bloomFilterColumns(columns)
    // The original message used `$options.bloomFilterColumns`, which interpolates only
    // `options` and then prints the literal text ".bloomFilterColumns".
    logger.debug(s"Using bloomFilterColumns = $columns")
  }
  // created on first use, after all writer options above have been applied
  private lazy val writer = OrcFile.createWriter(path, woptions)

  // total number of rows handed to the orc writer so far
  private val counter = new AtomicInteger(0)

  /** Buffers `row` and flushes once `batchSize` rows have accumulated. */
  def write(row: Row): Unit = {
    buffer.append(row)
    if (buffer.size == batchSize)
      flush()
  }

  /** Number of rows that have been flushed to the orc writer. */
  def records: Int = counter.get()

  /** Copies the buffered rows into the row batch, writes the batch and resets both. */
  def flush(): Unit = {

    // writes a single cell into the column vector matching its column index
    def writecol[T <: ColumnVector](rowIndex: Int, colIndex: Int, row: Row): Unit = {
      val value = row.values(colIndex)
      val vector = batch.cols(colIndex).asInstanceOf[T]
      val serializer = serializers(colIndex).asInstanceOf[OrcSerializer[T]]
      serializer.writeToVector(rowIndex, vector, value)
    }

    // don't use foreach here, using old school for loops for perf
    for (rowIndex <- buffer.indices) {
      val row = buffer(rowIndex)
      for (colIndex <- batch.cols.indices) {
        writecol(rowIndex, colIndex, row)
      }
    }

    batch.size = buffer.size
    writer.addRowBatch(batch)
    counter.updateAndGet(new IntUnaryOperator {
      override def applyAsInt(operand: Int): Int = operand + batch.size
    })
    buffer.clear()
    batch.reset()
  }

  /**
   * Flushes any buffered rows and closes the underlying writer.
   *
   * @return the row count reported by the orc writer
   */
  def close(): Long = {
    if (buffer.nonEmpty)
      flush()
    writer.close()
    // NOTE(review): the count is read after close(); presumably the orc writer only reports
    // an accurate figure once closed - confirm against the orc Writer documentation.
    val count = writer.getNumberOfRows
    logger.info(s"Orc writer wrote $count rows")
    count
  }
}
Example 108
Source File: SKRSpec.scala    From spark-kafka-writer   with Apache License 2.0 5 votes vote down vote up
package com.github.benfradet.spark.kafka.writer

import java.util.concurrent.atomic.AtomicInteger

import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer}
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.kafka010.{ConsumerStrategies, KafkaUtils, LocationStrategies}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.scalatest.concurrent.Eventually
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}

import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

/** Simple two-field record; presumably a sample payload for the specs built on SKRSpec - confirm against its usages. */
case class Foo(a: Int, b: String)

/**
 * Base trait for specs that need a Kafka test broker plus a fresh StreamingContext,
 * SparkSession and topic for every test.
 */
trait SKRSpec
  extends AnyWordSpec
  with Matchers
  with BeforeAndAfterEach
  with BeforeAndAfterAll
  with Eventually {

  val sparkConf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(getClass.getSimpleName)

  // Kafka test harness; created in beforeAll and torn down in afterAll
  var ktu: KafkaTestUtils = _

  /** Starts the Kafka test utilities once for the whole suite. */
  override def beforeAll(): Unit = {
    val utils = new KafkaTestUtils
    ktu = utils
    utils.setup()
  }

  /** Resets the shared callback counter and shuts the Kafka test utilities down. */
  override def afterAll(): Unit = {
    SKRSpec.callbackTriggerCount.set(0)
    Option(ktu).foreach { utils =>
      utils.tearDown()
      ktu = null
    }
  }

  // per-test state, (re)assigned in beforeEach
  var topic: String = _
  var ssc: StreamingContext = _
  var spark: SparkSession = _

  /** Stops the streaming context and the spark session of the finished test, if any. */
  override def afterEach(): Unit = {
    Option(ssc).foreach { context =>
      context.stop()
      ssc = null
    }
    Option(spark).foreach { session =>
      session.stop()
      spark = null
    }
  }

  /** Creates a new streaming context, spark session and randomly named topic. */
  override def beforeEach(): Unit = {
    ssc = new StreamingContext(sparkConf, Seconds(1))
    spark = SparkSession.builder.config(sparkConf).getOrCreate()
    topic = s"topic-${Random.nextInt()}"
    ktu.createTopics(topic)
  }

  /**
   * Registers a direct stream on `topic` that appends every record value to the returned
   * buffer. The buffer is returned immediately and is filled as batches are processed.
   */
  def collect(ssc: StreamingContext, topic: String): ArrayBuffer[String] = {
    val kafkaParams = Map(
      "bootstrap.servers" -> ktu.brokerAddress,
      "auto.offset.reset" -> "earliest",
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "test-collect"
    )
    val collected = new ArrayBuffer[String]
    val stream = KafkaUtils.createDirectStream[String, String](
      ssc,
      LocationStrategies.PreferConsistent,
      ConsumerStrategies.Subscribe[String, String](Set(topic), kafkaParams)
    )
    val values = stream.map(_.value())
    values.foreachRDD { rdd =>
      collected ++= rdd.collect()
      ()
    }
    collected
  }

  val producerConfig = Map(
    "bootstrap.servers" -> "127.0.0.1:9092",
    "key.serializer" -> classOf[StringSerializer].getName,
    "value.serializer" -> classOf[StringSerializer].getName
  )
}

/** Holder for state shared across SKRSpec-based specs. */
object SKRSpec {
  // starts at 0; the SKRSpec trait sets it back to 0 in afterAll
  val callbackTriggerCount = new AtomicInteger()
}
Example 109
Source File: SidechainBlockInfo.scala    From Sidechains-SDK   with MIT License 5 votes vote down vote up
package com.horizen.chain

import com.horizen.block.SidechainBlock
import com.horizen.utils.{WithdrawalEpochInfo, WithdrawalEpochInfoSerializer}
import com.horizen.vrf.{VrfOutput, VrfOutputSerializer}
import scorex.core.NodeViewModifier
import scorex.core.block.Block.Timestamp
import scorex.core.consensus.ModifierSemanticValidity
import scorex.core.serialization.{BytesSerializable, ScorexSerializer}
import scorex.util.serialization.{Reader, Writer}
import scorex.util.{ModifierId, bytesToId, idToBytes}

import scala.collection.mutable.ArrayBuffer

/**
 * Per-block metadata record, serialized with SidechainBlockInfoSerializer.
 *
 * Its parent link (`getParentId`) is the `parentId` field.
 */
case class SidechainBlockInfo(height: Int,
                              score: Long,
                              parentId: ModifierId,
                              timestamp: Timestamp,
                              semanticValidity: ModifierSemanticValidity,
                              mainchainHeaderHashes: Seq[MainchainHeaderHash],
                              mainchainReferenceDataHeaderHashes: Seq[MainchainHeaderHash],
                              withdrawalEpochInfo: WithdrawalEpochInfo,
                              vrfOutputOpt: Option[VrfOutput],
                              lastBlockInPreviousConsensusEpoch: ModifierId) extends BytesSerializable with LinkedElement[ModifierId] {

  // LinkedElement contract: the parent of this element is the parent block id
  override def getParentId: ModifierId = parentId

  override type M = SidechainBlockInfo

  override lazy val serializer: ScorexSerializer[SidechainBlockInfo] = SidechainBlockInfoSerializer

  // byte form produced by the companion serializer
  override def bytes: Array[Byte] = SidechainBlockInfoSerializer.toBytes(this)
}

object SidechainBlockInfo {

  /** Hashes of all mainchain headers contained in `sidechainBlock`, in block order. */
  def mainchainHeaderHashesFromBlock(sidechainBlock: SidechainBlock): Seq[MainchainHeaderHash] =
    for (header <- sidechainBlock.mainchainHeaders) yield byteArrayToMainchainHeaderHash(header.hash)

  /** Header hashes of all mainchain block reference data contained in `sidechainBlock`, in block order. */
  def mainchainReferenceDataHeaderHashesFromBlock(sidechainBlock: SidechainBlock): Seq[MainchainHeaderHash] =
    for (referenceData <- sidechainBlock.mainchainBlockReferencesData)
      yield byteArrayToMainchainHeaderHash(referenceData.headerHash)
}

/**
 * Binary serializer for SidechainBlockInfo.
 *
 * Field order on the wire: height, score, parent id, timestamp, semantic validity code,
 * mainchain header hashes, mainchain reference data header hashes, withdrawal epoch info,
 * optional vrf output, id of the last block of the previous consensus epoch.
 */
object SidechainBlockInfoSerializer extends ScorexSerializer[SidechainBlockInfo] {

  /** Writes `obj` to `w` in the field order described on this object. */
  override def serialize(obj: SidechainBlockInfo, w: Writer): Unit = {
    w.putInt(obj.height)
    w.putLong(obj.score)
    w.putBytes(idToBytes(obj.parentId))
    w.putLong(obj.timestamp)
    w.put(obj.semanticValidity.code)

    writeMainchainHeadersHashes(obj.mainchainHeaderHashes, w)
    writeMainchainHeadersHashes(obj.mainchainReferenceDataHeaderHashes, w)

    WithdrawalEpochInfoSerializer.serialize(obj.withdrawalEpochInfo, w)

    w.putOption(obj.vrfOutputOpt) { case (writer: Writer, vrfOutput: VrfOutput) =>
      VrfOutputSerializer.getSerializer.serialize(vrfOutput, writer)
    }

    w.putBytes(idToBytes(obj.lastBlockInPreviousConsensusEpoch))
  }

  // Length-prefixed list of hashes: an int count followed by the raw bytes of each hash.
  private def writeMainchainHeadersHashes(hashes: Seq[MainchainHeaderHash], w: Writer): Unit = {
    w.putInt(hashes.size)
    for (hash <- hashes) w.putBytes(hash.data)
  }

  // Reads a length-prefixed list of hashes, each `mainchainHeaderHashSize` bytes long.
  private def readMainchainHeadersHashes(r: Reader): Seq[MainchainHeaderHash] = {
    val count = r.getInt()
    // the by-name element is evaluated once per slot, so hashes are read sequentially
    ArrayBuffer.fill(count)(byteArrayToMainchainHeaderHash(r.getBytes(mainchainHeaderHashSize)))
  }

  /** Reads a SidechainBlockInfo from `r`, mirroring the order used by `serialize`. */
  override def parse(r: Reader): SidechainBlockInfo = {
    val blockHeight = r.getInt()
    val blockScore = r.getLong()
    val parent = bytesToId(r.getBytes(NodeViewModifier.ModifierIdSize))
    val blockTimestamp = r.getLong()
    val validityCode = r.getByte()
    val headerHashes = readMainchainHeadersHashes(r)
    val referenceDataHeaderHashes = readMainchainHeadersHashes(r)
    val epochInfo = WithdrawalEpochInfoSerializer.parse(r)
    val vrfOutput = r.getOption(VrfOutputSerializer.getSerializer.parse(r))
    val previousEpochLastBlock = bytesToId(r.getBytes(NodeViewModifier.ModifierIdSize))

    SidechainBlockInfo(
      height = blockHeight,
      score = blockScore,
      parentId = parent,
      timestamp = blockTimestamp,
      semanticValidity = ModifierSemanticValidity.restoreFromCode(validityCode),
      mainchainHeaderHashes = headerHashes,
      mainchainReferenceDataHeaderHashes = referenceDataHeaderHashes,
      withdrawalEpochInfo = epochInfo,
      vrfOutputOpt = vrfOutput,
      lastBlockInPreviousConsensusEpoch = previousEpochLastBlock
    )
  }
}
Example 110
Source File: IODBStoreAdapter.scala    From Sidechains-SDK   with MIT License 5 votes vote down vote up
package com.horizen.storage

import java.util.{ArrayList => JArrayList, List => JList}
import java.util.Optional
import com.horizen.utils.Pair

import scala.collection.JavaConverters._

import io.iohk.iodb.Store
import com.horizen.utils.ByteArrayWrapper

import scala.collection.mutable.ArrayBuffer

/**
 * Adapter exposing an IODB `Store` through the Java-friendly `Storage` interface:
 * Scala `Option`s become `java.util.Optional`s and Scala collections become Java lists.
 *
 * @param store the wrapped store; every call is delegated to it
 */
class IODBStoreAdapter (store : Store)
  extends Storage {

  /** Value stored under `key`, or an empty Optional when the store has none. */
  override def get(key: ByteArrayWrapper): Optional[ByteArrayWrapper] = {
    store.get(key) match {
      case Some(value) => Optional.of(new ByteArrayWrapper(value))
      case None => Optional.empty()
    }
  }

  /** Value stored under `key`, or `defaultValue` when the store has none. */
  override def getOrElse(key: ByteArrayWrapper, defaultValue: ByteArrayWrapper): ByteArrayWrapper = {
    store.get(key) match {
      case Some(value) => new ByteArrayWrapper(value)
      case None => defaultValue
    }
  }

  /**
   * Bulk lookup: one (key, optional value) pair per entry returned by the store,
   * in the order the store returns them.
   */
  override def get(keys: JList[ByteArrayWrapper]): JList[Pair[ByteArrayWrapper, Optional[ByteArrayWrapper]]] = {
    val values = new JArrayList[Pair[ByteArrayWrapper,Optional[ByteArrayWrapper]]]()

    store.get(keys.asScala).foreach { case (key, found) =>
      val value: Optional[ByteArrayWrapper] = found match {
        case Some(bytes) => Optional.of(new ByteArrayWrapper(bytes))
        case None => Optional.empty()
      }
      values.add(new Pair[ByteArrayWrapper,Optional[ByteArrayWrapper]](new ByteArrayWrapper(key), value))
    }

    values
  }

  /** All (key, value) pairs returned by `store.getAll()`, in iteration order. */
  override def getAll: JList[Pair[ByteArrayWrapper, ByteArrayWrapper]] = {
    val values = new JArrayList[Pair[ByteArrayWrapper,ByteArrayWrapper]]()

    for ((key, value) <- store.getAll())
      values.add(new Pair[ByteArrayWrapper,ByteArrayWrapper](new ByteArrayWrapper(key),
        new ByteArrayWrapper(value)))

    values
  }

  /** Latest version id of the store, or an empty Optional when it reports none. */
  override def lastVersionID(): Optional[ByteArrayWrapper] = {
    store.lastVersionID match {
      case Some(version) => Optional.of(new ByteArrayWrapper(version))
      case None => Optional.empty()
    }
  }

  /**
   * Applies one batch of changes under `version`.
   *
   * @param version  version id passed to the store for this update
   * @param toUpdate key/value pairs to write
   * @param toRemove keys to delete
   */
  override def update(version: ByteArrayWrapper, toUpdate: JList[Pair[ByteArrayWrapper, ByteArrayWrapper]],
                      toRemove: JList[ByteArrayWrapper]): Unit = {

    // the store expects Scala tuples, so the Java pairs are converted first
    val listToUpdate = new ArrayBuffer[Tuple2[ByteArrayWrapper,ByteArrayWrapper]]()

    for (r <- toUpdate.asScala) {
      listToUpdate.append(new Tuple2[ByteArrayWrapper, ByteArrayWrapper](r.getKey, r.getValue))
    }

    store.update(version, toRemove.asScala, listToUpdate)
  }

  /** Delegates to `store.rollback(version)`. */
  override def rollback(version : ByteArrayWrapper): Unit = {
    store.rollback(version)
  }

  /** Versions returned by `store.rollbackVersions()`, in the order the store returns them. */
  override def rollbackVersions(): JList[ByteArrayWrapper] = {
    val versions = store.rollbackVersions()
    val value = new JArrayList[ByteArrayWrapper]()
    for (v <- versions)
      value.add(new ByteArrayWrapper(v))

    value
  }

  /** True when the store reports no last version id. */
  override def isEmpty(): Boolean = !lastVersionID().isPresent

  /** Closes the wrapped store. */
  override def close(): Unit = {
    store.close()
  }
}
Example 111
Source File: StoreOpsTest.scala    From fs2-blobstore   with Apache License 2.0 5 votes vote down vote up
package blobstore

import java.nio.charset.Charset
import java.nio.file.Files
import java.util.concurrent.Executors

import cats.effect.{Blocker, IO}
import cats.effect.laws.util.TestInstances
import cats.implicits._
import fs2.Pipe
import org.scalatest.Assertion
import org.scalatest.flatspec.AnyFlatSpec
import implicits._
import org.scalatest.matchers.must.Matchers

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext


/**
 * Tests for the put helpers (`bufferedPut`, `put` from a nio path) using a DummyStore
 * that records the bytes it receives and checks the Path it is called with.
 */
class StoreOpsTest extends AnyFlatSpec with Matchers with TestInstances {

  implicit val cs = IO.contextShift(ExecutionContext.global)
  // blocker backed by a cached thread pool; passed to the put helpers and to the file writer
  val blocker = Blocker.liftExecutionContext(ExecutionContext.fromExecutor(Executors.newCachedThreadPool))

  behavior of "PutOps"
  it should "buffer contents and compute size before calling Store.put" in {
    val bytes: Array[Byte] = "AAAAAAAAAA".getBytes(Charset.forName("utf-8"))
    // the check asserts that the path handed to Store.put carries the total byte count
    val store = DummyStore(_.size must be(Some(bytes.length)))

    fs2.Stream.emits(bytes).covary[IO].through(store.bufferedPut(Path("path/to/file.txt"), blocker)).compile.drain.unsafeRunSync()
    store.buf.toArray must be(bytes)

  }

  it should "upload a file from a nio Path" in {
    val bytes = "hello".getBytes(Charset.forName("utf-8"))
    val store = DummyStore(_.size must be(Some(bytes.length)))

    // temp file is created as a bracketed resource and deleted on release
    fs2.Stream.bracket(IO(Files.createTempFile("test-file", ".bin"))) { p =>
      IO(p.toFile.delete).void
    }.flatMap { p =>
      // first write the bytes to the temp file, then upload that file through the store
      fs2.Stream.emits(bytes).covary[IO].through(fs2.io.file.writeAll(p, blocker)).drain ++
        fs2.Stream.eval(store.put(p, Path("path/to/file.txt"), blocker))
    }.compile.drain.unsafeRunSync()
    store.buf.toArray must be(bytes)
  }

}

/**
 * Test double for Store[IO]: only `put` is implemented; every other operation is `???`.
 *
 * @param check assertion run against the path each time `put` is called
 */
final case class DummyStore(check: Path => Assertion) extends Store[IO] {
  // accumulates every byte written through `put`
  val buf = new ArrayBuffer[Byte]()
  override def put(path: Path): Pipe[IO, Byte, Unit] = {
    // the path is checked when the pipe is created, not when it is run
    check(path)
    in => {
      // the input stream is run synchronously at the moment the pipe is applied
      buf.appendAll(in.compile.toVector.unsafeRunSync())
      fs2.Stream.emit(())
    }
  }
  override def list(path: Path): fs2.Stream[IO, Path] = ???
  override def get(path: Path, chunkSize: Int): fs2.Stream[IO, Byte] = ???
  override def move(src: Path, dst: Path): IO[Unit] = ???
  override def copy(src: Path, dst: Path): IO[Unit] = ???
  override def remove(path: Path): IO[Unit] = ???
}
Example 112
Source File: MetadataTransformUtils.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature.operator

import org.apache.spark.sql.types.{MetadataBuilder, StructField}

import scala.collection.mutable.ArrayBuffer


  /**
   * Builds the metadata for a column produced by a cartesian product of vector columns.
   *
   * The derivation of the first field is combined, left to right, with the derivation of
   * every following field via `cartesianWithArray`. A field without a DERIVATION entry
   * contributes `createDerivation(numFeatures)` instead.
   *
   * @param fields      input columns; at least two are required
   * @param numFeatures passed to `createDerivation` for fields lacking a stored derivation
   * @return a builder seeded with the last field's metadata plus the combined derivation
   * @throws IllegalArgumentException if fewer than two fields are given
   */
  def vectorCartesianTransform(fields: Array[StructField], numFeatures: Int): MetadataBuilder = {
    if (fields.length < 2) {
      throw new IllegalArgumentException("the number of cols in the input DataFrame should be no less than 2")
    }

    // stored derivation when present, otherwise a freshly created one
    def derivationOf(field: StructField): Array[String] =
      if (field.metadata.contains(DERIVATION)) field.metadata.getStringArray(DERIVATION)
      else createDerivation(numFeatures)

    val combined = fields.tail.foldLeft(derivationOf(fields.head)) { (accumulated, field) =>
      cartesianWithArray(accumulated, derivationOf(field))
    }

    new MetadataBuilder().withMetadata(fields.last.metadata).putStringArray(DERIVATION, combined)
  }

} 
Example 113
Source File: Message.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.network.nio

import java.net.InetSocketAddress
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import com.google.common.base.Charsets.UTF_8

import org.apache.spark.util.Utils

/**
 * Base class of messages exchanged by the nio layer.
 *
 * @param typ numeric message type tag
 * @param id  message id
 */
private[nio] abstract class Message(val typ: Long, val id: Int) {
  // mutable bookkeeping filled in by whoever sends or receives the message
  var senderAddress: InetSocketAddress = null
  var started = false
  var startTime = -1L
  var finishTime = -1L
  var isSecurityNeg = false
  var hasError = false

  /** Size of the message as defined by the concrete subclass. */
  def size: Int

  /** Next chunk to send, limited to `maxChunkSize`, if any. */
  def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]

  /** Chunk of `chunkSize` to receive into, if any. */
  def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]

  /** Difference between `finishTime` and `startTime`, rendered with an " ms" suffix. */
  def timeTaken(): String = s"${finishTime - startTime} ms"

  override def toString: String = s"${this.getClass.getSimpleName}(id = $id, size = $size)"
}


/** Factory methods and id generation for nio messages. */
private[nio] object Message {
  // type tag used for buffer messages
  val BUFFER_MESSAGE = 1111111111L

  var lastId = 1

  /** Returns the next message id; 0 is never handed out. */
  def getNewId(): Int = synchronized {
    val candidate = lastId + 1
    lastId = if (candidate == 0) candidate + 1 else candidate
    lastId
  }

  /**
   * Creates a buffer message from `dataBuffers`.
   * A null sequence yields a message without buffers; a null element is rejected.
   */
  def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
    if (dataBuffers == null) {
      new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
    } else if (dataBuffers.exists(_ == null)) {
      throw new Exception("Attempting to create buffer message with null buffer")
    } else {
      new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
    }
  }

  def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
    createBufferMessage(dataBuffers, 0)

  /** Creates a buffer message from a single buffer; null is replaced by an empty buffer. */
  def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
    // a buffer must exist before anything can be read or written, so null becomes a
    // zero-capacity buffer obtained from the static allocate() method
    val payload = if (dataBuffer == null) ByteBuffer.allocate(0) else dataBuffer
    createBufferMessage(Array(payload), ackId)
  }

  def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
    createBufferMessage(dataBuffer, 0)

  /** Creates a buffer message that carries no buffers, only `ackId`. */
  def createBufferMessage(ackId: Int): BufferMessage = {
    createBufferMessage(new Array[ByteBuffer](0), ackId)
  }

  /** Creates a buffer message flagged as an error whose payload is the exception's string form in UTF-8. */
  def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = {
    val text = Utils.exceptionString(exception)
    val message = createBufferMessage(ByteBuffer.wrap(text.getBytes(UTF_8)), ackId)
    message.hasError = true
    message
  }

  /**
   * Creates an empty message matching `header`, ready to receive `header.totalSize` bytes.
   * Only BUFFER_MESSAGE headers are matched.
   */
  def create(header: MessageChunkHeader): Message = {
    val message: Message = header.typ match {
      case BUFFER_MESSAGE =>
        // pre-allocate one buffer large enough for the whole message
        val payload = ArrayBuffer(ByteBuffer.allocate(header.totalSize))
        new BufferMessage(header.id, payload, header.other)
    }
    message.hasError = header.hasError
    message.senderAddress = header.address
    message
  }
}
Example 114
Source File: ApplicationInfo.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.master

import java.util.Date

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.deploy.ApplicationDescription
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils

/**
 * Master-side record of an application.
 *
 * The constructor parameters are serialized; the `@transient` fields are not and are set up
 * again by `init()` after deserialization (see `readObject`).
 * NOTE(review): `init()` is called below but its definition is not part of this excerpt.
 */
private[spark] class ApplicationInfo(
  val startTime: Long,
  val id: String,
  val desc: ApplicationDescription,
  val submitDate: Date,
  val driver: RpcEndpointRef,
  defaultCores: Int)
    extends Serializable {
  // enumeration-typed application state
  @transient var state: ApplicationState.Value = _
  @transient var executors: mutable.HashMap[Int, ExecutorDesc] = _
  @transient var removedExecutors: ArrayBuffer[ExecutorDesc] = _
  @transient var coresGranted: Int = _
  @transient var endTime: Long = _
  @transient var appSource: ApplicationSource = _

  // A cap on the number of executors this application can have at any given time.
  // By default, this is infinite. Only after the first allocation request is issued by the
  // application will this be set to a finite value. This is used for dynamic allocation.
  @transient private[master] var executorLimit: Int = _

  @transient private var nextExecutorId: Int = _

  init() // initialization method

  // custom deserialization hook: restore the default fields, then re-run init()
  private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    init()
  }
  
  private[deploy] def getExecutorLimit: Int = executorLimit

  /**
   * Elapsed time since `startTime`: up to `endTime` when it is set (anything other than -1),
   * otherwise up to the current time.
   */
  def duration: Long = {
    if (endTime != -1) {
      endTime - startTime
    } else {
      System.currentTimeMillis() - startTime
    }
  }

}
Example 115
Source File: Schedulable.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.scheduler.SchedulingMode.SchedulingMode


/**
 * Interface of a schedulable entity. Every member is abstract.
 * NOTE(review): presumably implemented by pools and task set managers - confirm against
 * the implementing classes, which are not part of this excerpt.
 */
private[spark] trait Schedulable {
  // enclosing pool of this schedulable
  var parent: Pool
  // child queues
  def schedulableQueue: ConcurrentLinkedQueue[Schedulable]
  def schedulingMode: SchedulingMode
  def weight: Int
  def minShare: Int
  def runningTasks: Int
  def priority: Int
  def stageId: Int
  def name: String

  // management of child schedulables
  def addSchedulable(schedulable: Schedulable): Unit
  def removeSchedulable(schedulable: Schedulable): Unit
  def getSchedulableByName(name: String): Schedulable
  // notification hooks and queries
  def executorLost(executorId: String, host: String): Unit
  def checkSpeculatableTasks(): Boolean
  def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager]
}
Example 116
Source File: ByteArrayChunkOutputStream.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.io

import java.io.OutputStream

import scala.collection.mutable.ArrayBuffer



  // Write offset inside the last chunk. It starts at chunkSize ("chunk full") so that the
  // first write makes allocateNewChunkIfNeeded() allocate the first chunk.
  private var position = chunkSize

  /** Appends the low-order byte of `b`, allocating a new chunk first when the last one is full. */
  override def write(b: Int): Unit = {
    allocateNewChunkIfNeeded()
    // chunks is indexed twice: first the chunk, then the offset inside that chunk
    val current = chunks(lastChunkIndex)
    current(position) = b.toByte
    position += 1
  }

  /**
   * Appends `len` bytes of `bytes`, starting at offset `off`, spreading them over as many
   * chunks as needed.
   */
  override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
    var copied = 0
    while (copied < len) {
      allocateNewChunkIfNeeded()
      // copy as much as fits into the current chunk, but no more than what is left
      val room = chunkSize - position
      val remaining = len - copied
      val count = if (room < remaining) room else remaining
      System.arraycopy(bytes, off + copied, chunks(lastChunkIndex), position, count)
      position += count
      copied += count
    }
  }

  // Appends a fresh chunk and makes it the write target when the current chunk is full
  // (position == chunkSize), resetting the write offset to 0. Does nothing otherwise.
  @inline
  private def allocateNewChunkIfNeeded(): Unit = {
    if (position == chunkSize) {
      chunks += new Array[Byte](chunkSize)
      lastChunkIndex += 1
      position = 0
    }
  }

  /**
   * Returns the written data as an array of chunks.
   *
   * All chunks but the last are returned as they are (not copied). The last chunk is
   * returned as is when it is completely filled, otherwise a trimmed copy holding only the
   * written bytes is returned. Nothing written yet yields an empty array.
   */
  def toArrays: Array[Array[Byte]] = {
    if (lastChunkIndex == -1) {
      new Array[Array[Byte]](0)
    } else {
      // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk.
      // An alternative would have been returning an array of ByteBuffers, with the last buffer
      // bounded to only the last chunk's position. However, given our use case in Spark (to put
      // the chunks in block manager), only limiting the view bound of the buffer would still
      // require the block manager to store the whole chunk.
      val result = new Array[Array[Byte]](chunks.size)
      var i = 0
      while (i < chunks.size - 1) {
        result(i) = chunks(i)
        i += 1
      }
      val lastChunk = chunks(lastChunkIndex)
      result(lastChunkIndex) =
        if (position == chunkSize) {
          lastChunk
        } else {
          val trimmed = new Array[Byte](position)
          System.arraycopy(lastChunk, 0, trimmed, 0, position)
          trimmed
        }
      result
    }
  }
} 
Example 117
Source File: MapPartitionsWithPreparationRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.{Partition, Partitioner, TaskContext}


  /**
   * Computes the partition by passing a prepared value and the parent's iterator to
   * `executePartition`. The prepared value is the first queued entry of `preparedArguments`
   * (which is removed from the queue) or, when none is queued, the result of
   * `preparePartition()`.
   */
  override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
    val preparedValue =
      if (preparedArguments.nonEmpty) preparedArguments.remove(0)
      else preparePartition()
    val input = firstParent[T].iterator(partition, context)
    executePartition(context, partition.index, preparedValue, input)
  }
} 
Example 118
Source File: UnionRDD.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils


/**
 * A partition of a union, pointing at one partition of one of the unioned parent RDDs.
 *
 * @param idx                     value exposed as this partition's `index`
 * @param rdd                     the parent RDD this partition belongs to
 * @param parentRddIndex          index of the parent RDD, as recorded by the creator
 * @param parentRddPartitionIndex position of the referenced partition inside `rdd.partitions`
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient rdd: RDD[T],
    val parentRddIndex: Int,
    @transient parentRddPartitionIndex: Int)
  extends Partition {

  // Looked up once at construction time and refreshed again right before serialization.
  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  /** Preferred locations reported by the parent RDD for the referenced parent partition. */
  def preferredLocations(): Seq[String] = {
    rdd.preferredLocations(parentPartition)
  }

  override val index: Int = idx

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = {
    Utils.tryOrIOException {
      // Re-resolve the parent partition at serialization time so the serialized form carries
      // the parent's current partition object.
      parentPartition = rdd.partitions(parentRddPartitionIndex)
      oos.defaultWriteObject()
    }
  }
}

/**
 * An RDD whose partitions are the concatenation, in order, of the partitions of `rdds`.
 *
 * @param sc   the Spark context
 * @param rdds the parent RDDs; set to null by `clearDependencies`
 */
@DeveloperApi
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  /** Builds one [[UnionPartition]] per parent partition, numbered consecutively across parents. */
  override def getPartitions: Array[Partition] = {
    val totalPartitions = rdds.map(_.partitions.length).sum
    val unioned = new Array[Partition](totalPartitions)
    var nextSlot = 0
    rdds.zipWithIndex.foreach { case (parentRdd, parentIndex) =>
      parentRdd.partitions.foreach { split =>
        unioned(nextSlot) = new UnionPartition(nextSlot, parentRdd, parentIndex, split.index)
        nextSlot += 1
      }
    }
    unioned
  }

  /**
   * One range dependency per parent: the parent's partitions `[0, n)` map onto the output range
   * starting at the running total of the partition counts of the parents before it.
   */
  override def getDependencies: Seq[Dependency[_]] = {
    // Output start offset of each parent = sum of partition counts of all earlier parents.
    val outputStarts = rdds.scanLeft(0)((total, parentRdd) => total + parentRdd.partitions.length)
    val deps = new ArrayBuffer[Dependency[_]]
    rdds.zip(outputStarts).foreach { case (parentRdd, outStart) =>
      deps += new RangeDependency(parentRdd, 0, outStart, parentRdd.partitions.length)
    }
    deps
  }

  /** Delegates to the iterator of the parent partition referenced by the union partition. */
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val unionPart = s.asInstanceOf[UnionPartition[T]]
    val parentRdd = parent[T](unionPart.parentRddIndex)
    parentRdd.iterator(unionPart.parentPartition, context)
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    val unionPart = s.asInstanceOf[UnionPartition[T]]
    unionPart.preferredLocations()
  }

  /** Clears inherited dependencies and drops the reference to the parent RDDs. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdds = null
  }
}
Example 119
Source File: hbaseCommands.scala    From Heracles   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hbase.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.hbase._
import org.apache.spark.sql.hbase.util.DataTypeUtils
import org.apache.spark.sql.types._

import scala.collection.mutable.ArrayBuffer

/**
 * Command that drops a non-key column from an HBase-backed table.
 *
 * @param namespace  namespace passed to the catalog
 * @param tableName  table passed to the catalog
 * @param columnName column to drop
 */
@DeveloperApi
case class AlterDropColCommand(namespace: String, tableName: String, columnName: String)
  extends RunnableCommand {

  /** Drops the column through the session's external catalog, then stops the catalog's admin. */
  def run(sparkSession: SparkSession): Seq[Row] = {
    val hbaseCatalog = sparkSession.sharedState.externalCatalog.asInstanceOf[HBaseCatalog]
    hbaseCatalog.alterTableDropNonKey(namespace, tableName, columnName)
    hbaseCatalog.stopAdmin()
    Seq.empty[Row]
  }
}

/**
 * Command that adds a non-key column to an HBase-backed table.
 *
 * @param namespace    namespace passed to the catalog
 * @param tableName    table passed to the catalog
 * @param colName      name of the new column
 * @param colType      type name, resolved through `DataTypeUtils.getDataType`
 * @param colFamily    column family of the new column
 * @param colQualifier column qualifier of the new column
 */
@DeveloperApi
case class AlterAddColCommand(namespace: String,
                              tableName: String,
                              colName: String,
                              colType: String,
                              colFamily: String,
                              colQualifier: String) extends RunnableCommand {

  /** Registers the new column with the external catalog, then stops the catalog's admin. */
  def run(sparkSession: SparkSession): Seq[Row] = {
    val catalog = sparkSession.sharedState.externalCatalog.asInstanceOf[HBaseCatalog]
    val newColumn =
      NonKeyColumn(colName, DataTypeUtils.getDataType(colType), colFamily, colQualifier)
    catalog.alterTableAddNonKey(namespace, tableName, newColumn)
    catalog.stopAdmin()
    Seq.empty[Row]
  }
}

/**
 * Command that inserts a single row, given as string values, into an HBase-backed table.
 *
 * @param tid      identifier of the target table
 * @param valueSeq the row's values as strings, positionally matched against the relation schema
 */
@DeveloperApi
case class InsertValueIntoTableCommand(tid: TableIdentifier, valueSeq: Seq[String])
  extends RunnableCommand {

  /**
   * Converts each string to the data type of the schema field at the same position, wraps the
   * result in a one-row DataFrame and inserts it without overwriting.
   *
   * NOTE(review): a missing relation yields null here, so the first use of `relation` would
   * fail with a NullPointerException — confirm callers guarantee the table exists.
   */
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val hbaseCatalog =
      sparkSession.sessionState.catalog.externalCatalog.asInstanceOf[HBaseCatalog]
    val relation: HBaseRelation =
      hbaseCatalog.getHBaseRelation(tid.database.orNull, tid.table).orNull

    val typedValues = valueSeq.zipWithIndex.map { case (text, position) =>
      DataTypeUtils.string2TypeData(text, relation.schema(position).dataType)
    }

    val singleRow = sparkSession.sparkContext.makeRDD(Seq(Row.fromSeq(typedValues)))
    val inputValuesDF = sparkSession.createDataFrame(singleRow, relation.schema)
    relation.insert(inputValuesDF, overwrite = false)

    Seq.empty[Row]
  }

  override def output: Seq[Attribute] = Seq.empty
}
Example 120
Source File: MeetupReceiver.scala    From meetup-stream   with Apache License 2.0 5 votes vote down vote up
package receiver

import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.storage.StorageLevel
import org.apache.spark.Logging
import com.ning.http.client.AsyncHttpClientConfig
import com.ning.http.client._
import scala.collection.mutable.ArrayBuffer
import java.io.OutputStream
import java.io.ByteArrayInputStream
import java.io.InputStreamReader
import java.io.BufferedReader
import java.io.InputStream
import java.io.PipedInputStream
import java.io.PipedOutputStream

/**
 * Receiver that issues an HTTP GET against `url`, pipes the response body through a
 * piped stream pair and stores every line read from it.
 *
 * @param url the URL to request
 */
class MeetupReceiver(url: String) extends Receiver[String](StorageLevel.MEMORY_AND_DISK_2) with Logging {

  @transient var client: AsyncHttpClient = _

  @transient var inputPipe: PipedInputStream = _
  @transient var outputPipe: PipedOutputStream = _

  /**
   * Builds the HTTP client, wires the output pipe into the input pipe, starts the thread that
   * reads lines from the input pipe, and finally fires the request whose body parts are written
   * to the output pipe.
   */
  def onStart(): Unit = {
    val clientConfig = new AsyncHttpClientConfig.Builder()
    clientConfig.setRequestTimeout(Integer.MAX_VALUE)
    clientConfig.setReadTimeout(Integer.MAX_VALUE)
    clientConfig.setPooledConnectionIdleTimeout(Integer.MAX_VALUE)
    client = new AsyncHttpClient(clientConfig.build())

    // 1 MiB pipe buffer between the HTTP callback and the line reader.
    inputPipe = new PipedInputStream(1024 * 1024)
    outputPipe = new PipedOutputStream(inputPipe)
    val lineReaderThread = new Thread(new DataConsumer(inputPipe))
    lineReaderThread.start()

    client.prepareGet(url).execute(new AsyncHandler[Unit]{

      // Forward every received body chunk into the pipe and keep the request going.
      def onBodyPartReceived(bodyPart: HttpResponseBodyPart) = {
        bodyPart.writeTo(outputPipe)
        AsyncHandler.STATE.CONTINUE
      }

      def onStatusReceived(status: HttpResponseStatus) = {
        AsyncHandler.STATE.CONTINUE
      }

      def onHeadersReceived(headers: HttpResponseHeaders) = {
        AsyncHandler.STATE.CONTINUE
      }

      def onCompleted = {
        println("completed")
      }

      def onThrowable(t: Throwable) = {
        t.printStackTrace()
      }

    })
  }

  /** Closes the client and both pipe ends, skipping whichever was never created. */
  def onStop(): Unit = {
    Option(client).foreach(_.close())
    Option(outputPipe).foreach { pipe =>
      pipe.flush()
      pipe.close()
    }
    Option(inputPipe).foreach(_.close())
  }

  /** Reads `inputStream` line by line and stores each line until the stream ends. */
  class DataConsumer(inputStream: InputStream) extends Runnable {

    override def run(): Unit = {
      val lineReader = new BufferedReader(new InputStreamReader(inputStream))
      // readLine returns null at end of stream, which terminates the iteration.
      Iterator
        .continually(lineReader.readLine())
        .takeWhile(_ != null)
        .foreach(line => store(line))
    }

  }

}
Example 121
Source File: HashBasedDeduplicator.scala    From pravda-ml   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.odkl.texts

import odkl.analysis.spark.util.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.linalg.Vectors.norm
import org.apache.spark.ml.linalg.{BLAS, Vector}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, Row}

import scala.collection.mutable.ArrayBuffer


  /**
   * Sets the `similarityThreshold` param.
   *
   * @param value the new threshold
   * @return this instance, for call chaining
   */
  def setSimilarityTreshold(value: Double): this.type = {
    set(similarityThreshold, value)
  }

  // Defaults: the hash is read from column "hash" and the similarity threshold is 0.9.
  setDefault(
    inputColHash -> "hash",
    similarityThreshold -> 0.9)

  /** Creates an instance whose uid is generated from the "hashBasedDeduplication" prefix. */
  def this() = {
    this(Identifiable.randomUID("hashBasedDeduplication"))
  }

  /**
   * Removes near-duplicate rows inside each hash bucket.
   *
   * Rows are repartitioned by the hash column and sorted by it within each partition, so rows
   * sharing a hash arrive consecutively. Within such a run a row is kept only if the cosine
   * similarity between its vector and every vector already kept for that run is strictly
   * below `similarityThreshold`. The first row of each run is always kept.
   *
   * @param dataset input data containing the hash column and the vector column
   * @return a DataFrame with the same schema holding only the rows that were kept
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Resolve the params once on the driver: the partition closure then captures three plain
    // values instead of this transformer, and no param lookup happens per comparison.
    val hashColumn = $(inputColHash)
    val vectorColumn = $(inputColVector)
    val threshold = $(similarityThreshold)

    val keptRows = dataset.toDF
      .repartition(dataset.col(hashColumn))
      .sortWithinPartitions(hashColumn)
      .rdd
      .mapPartitions((rows: Iterator[Row]) => {
        var currentHash: Long = -1L
        // Vectors already kept for the current hash run.
        val keptVectors = new ArrayBuffer[Vector](0)
        // Dropping duplicates directly in flatMap avoids emitting a placeholder row and
        // filtering it out in a second pass. An empty partition simply yields nothing.
        rows.flatMap { row =>
          val rowHash = row.getAs[Long](hashColumn)
          val rowVector = row.getAs[Vector](vectorColumn)
          if (rowHash != currentHash) {
            // A new run starts: forget the previous run's vectors.
            keptVectors.clear()
            currentHash = rowHash
          }
          // With an empty buffer (start of a run) forall is vacuously true, so the row is kept.
          val isUnique = keptVectors.forall { kept =>
            val cosine = BLAS.dot(kept, rowVector) / (norm(kept, 2) * norm(rowVector, 2))
            cosine < threshold
          }
          if (isUnique) {
            keptVectors.append(rowVector)
            Iterator.single(row)
          } else {
            Iterator.empty
          }
        }
      })

    dataset.sqlContext.createDataFrame(keptRows, transformSchema(dataset.schema))
  }

  /**
   * The output schema is the input schema, returned unchanged.
   *
   * @param schema the input schema
   * @return `schema` itself
   */
  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = schema

  /** Copies this transformer via `defaultCopy`, passing along the extra params. */
  override def copy(extra: ParamMap): Transformer = {
    defaultCopy(extra)
  }


} 
Example 122
Source File: NonSampleCompactor.scala    From deequ   with Apache License 2.0 5 votes vote down vote up
package com.amazon.deequ.analyzers

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.util.Random


    val output = (offset until len by 2).map(sortedBuffer(_)).toArray
    val tail = findOdd(items)
    items = items % 2
    var newBuffer = ArrayBuffer[T]()
    if (tail.isDefined) {
      newBuffer = newBuffer :+ tail.get
    }
    buffer = newBuffer
    numOfCompress = numOfCompress + 1
    output
  }
} 
Example 123
Source File: ScaleAndConvert.scala    From SparkNet   with MIT License 5 votes vote down vote up
package preprocessing

import java.awt.image.DataBufferByte
import java.io.ByteArrayInputStream
import javax.imageio.ImageIO

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import net.coobird.thumbnailator._

import org.apache.spark.rdd.RDD

import libs._

/** Helpers that turn images into planar RGB byte arrays. */
object ScaleAndConvert {
  /**
   * Converts an image into a planar byte array of length `3 * height * width`: all red
   * values first, then all green values, then all blue values, each plane in row-major order.
   *
   * @param image the image to convert
   * @return the planar R, G, B bytes
   */
  def BufferedImageToByteArray(image: java.awt.image.BufferedImage) : Array[Byte] = {
    val height = image.getHeight()
    val width = image.getWidth()
    // Row-major packed pixels; pixel (row, col) is at index row * width + col.
    val packed = image.getRGB(0, 0, width, height, null, 0, width)
    val planeSize = height * width
    val planar = new Array[Byte](3 * planeSize)
    var pixel = 0
    while (pixel < planeSize) {
      val rgb = packed(pixel)
      planar(pixel) = ((rgb >> 16) & 0xFF).toByte
      planar(planeSize + pixel) = ((rgb >> 8) & 0xFF).toByte
      planar(2 * planeSize + pixel) = (rgb & 0xFF).toByte
      pixel += 1
    }
    planar
  }

  /**
   * Decodes a compressed image, forces it to `width` x `height` and converts it with
   * [[BufferedImageToByteArray]].
   *
   * @param compressedImage the encoded image bytes
   * @param height          target height
   * @param width           target width
   * @return the planar bytes, or None when decoding or resizing throws an
   *         IllegalArgumentException, IIOException or NullPointerException
   */
  def decompressImageAndResize(compressedImage: Array[Byte], height: Int, width: Int) : Option[Array[Byte]] = {
    try {
      val decoded = ImageIO.read(new ByteArrayInputStream(compressedImage))
      val resized = Thumbnails.of(decoded).forceSize(width, height).asBufferedImage()
      Some(BufferedImageToByteArray(resized))
    } catch {
      // Images that cannot be processed are deliberately ignored.
      case _: java.lang.IllegalArgumentException |
           _: javax.imageio.IIOException |
           _: java.lang.NullPointerException => None
    }
  }
}
Example 124
Source File: ClassRDDPartitioner.scala    From spark-orientdb-connector   with Apache License 2.0 5 votes vote down vote up
package com.metreta.spark.orientdb.connector.rdd.partitioner

import scala.collection.JavaConversions.iterableAsScalaIterable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Logging
import org.apache.spark.Partition

import com.metreta.spark.orientdb.connector.api.OrientDBConnector
import com.orientechnologies.orient.core.metadata.schema.OClass
import com.orientechnologies.orient.core.metadata.schema.OSchema
import com.orientechnologies.orient.core.storage.OStorage
import com.metreta.spark.orientdb.connector.SystemTables
import scala.collection.JavaConversions.iterableAsScalaIterable


  /**
   * Builds one partition per cluster id of the class named by `mClass`.
   *
   * Each partition carries its position in the cluster id list, a null host address and a
   * `PartitionName` made of the class name and the cluster's name as reported by the storage.
   *
   * @return the partitions, in cluster id order
   */
  def getPartitions(): Array[Partition] = {
    val db = connector.databaseDocumentTx()

    val schema: OSchema = connector.getSchema(db)
    val klass: OClass = schema.getClass(mClass)
    val storage: OStorage = connector.getStorage(db)

    val partitions = new ArrayBuffer[OrientPartition]
    klass.getClusterIds.zipWithIndex.foreach { case (clusterId, index) =>
      partitions += OrientPartition(
        index,
        null, // host address is not known here
        PartitionName(klass.getName, storage.getClusterById(clusterId).getName))
    }
    partitions.toArray
  }

} 
Example 125
Source File: SparkContextFunctionsSpec.scala    From spark-orientdb-connector   with Apache License 2.0 5 votes vote down vote up
package com.metreta.spark.orientdb.connector

import scala.collection.mutable.ArrayBuffer
import org.scalatest.BeforeAndAfterAll
import com.orientechnologies.orient.core.id.ORID
import com.metreta.spark.orientdb.connector.utils.BaseOrientDbFlatSpec

/**
 * Checks that vertex ids derived from record id strings are unique and non-negative over a
 * grid of cluster ids and record ids.
 */
class SparkContextFunctionsSpec extends BaseOrientDbFlatSpec {

  var oridList: ArrayBuffer[String] = new ArrayBuffer
  var MaxCluster = 1000
  var MaxRecord = 1000

  override def beforeAll(): Unit = {
    initSparkConf(defaultSparkConf)
    createOridList()
  }

  override def afterAll(): Unit = {
    sparkContext.stop()
  }

  "A VertexId created from RID" should "be unique" in {
    val vertexIds = oridList.map(rid => sparkContext.getVertexIdFromString(rid))
    // Any id that occurs more than once is a collision.
    val duplicatedValues = vertexIds.groupBy(identity).collect {
      case (vertexId, occurrences) if occurrences.lengthCompare(1) > 0 => vertexId
    }
    duplicatedValues shouldBe empty
  }

  it should "be a positive number" in {
    val negativeValues = oridList.filter(rid => sparkContext.getVertexIdFromString(rid) < 0)
    negativeValues shouldBe empty
  }

  /**
   * Appends one record id string per (clusterId, recordId) pair with both bounds inclusive,
   * formatted as prefix, cluster id, separator, record id.
   */
  def createOridList(): Unit = {
    for {
      clusterId <- 0 to MaxCluster
      recordId <- 0 to MaxRecord
    } {
      oridList += s"${ORID.PREFIX}$clusterId${ORID.SEPARATOR}$recordId"
    }
  }

}
Example 126
Source File: SpearmanCorrelation.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.stat.correlation

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
import org.apache.spark.rdd.RDD


  /**
   * Computes a correlation matrix from ranks: every value of `X` is replaced by its rank
   * within its column (tied values share an average rank), and the resulting rank vectors are
   * passed to `PearsonCorrelation.computeCorrelationMatrix`.
   *
   * @param X the input rows, one vector per row
   * @return the matrix returned by `PearsonCorrelation.computeCorrelationMatrix` for the ranks
   */
  override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
    // ((columnIndex, value), rowUid)
    val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) =>
      vec.toArray.view.zipWithIndex.map { case (v, j) =>
        ((j, v), uid)
      }
    }
    // global sort by (columnIndex, value)
    val sorted = colBased.sortByKey()
    // assign global ranks (using average ranks for tied values)
    val globalRanks = sorted.zipWithIndex().mapPartitions { iter =>
      // State of the run of equal (column, value) entries currently being collected.
      var preCol = -1
      var preVal = Double.NaN
      var startRank = -1.0
      val cachedUids = ArrayBuffer.empty[Long]
      // Emits every cached row uid with the run's average rank, then empties the cache.
      // The run occupies consecutive ranks startRank .. startRank + size - 1, whose mean is
      // startRank + (size - 1) / 2. With an empty cache this returns an empty collection.
      val flush: () => Iterable[(Long, (Int, Double))] = () => {
        val averageRank = startRank + (cachedUids.size - 1) / 2.0
        val output = cachedUids.map { uid =>
          (uid, (preCol, averageRank))
        }
        cachedUids.clear()
        output
      }
      iter.flatMap { case (((j, v), uid), rank) =>
        // If we see a new value or cachedUids is too big, we flush ids with their average rank.
        // NOTE: when the size cap triggers the flush, a single run of tied values is split
        // and each piece receives its own average rank.
        if (j != preCol || v != preVal || cachedUids.size >= 10000000) {
          val output = flush()
          preCol = j
          preVal = v
          startRank = rank
          cachedUids += uid
          output
        } else {
          cachedUids += uid
          Iterator.empty
        }
      // Iterator.++ takes its argument by name, so this final flush() is evaluated only once
      // the elements above have been consumed; it emits the last run still held in the cache.
      } ++ flush()
    }
    // Replace values in the input matrix by their ranks compared with values in the same column.
    // Note that shifting all ranks in a column by a constant value doesn't affect result.
    val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) =>
      // sort by column index and then convert values to a vector
      Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray)
    }
    PearsonCorrelation.computeCorrelationMatrix(groupedRanks)
  }
} 
Example 127
Source File: ApplicationMasterArguments.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.yarn

import scala.collection.mutable.ArrayBuffer

/**
 * Parses the command-line arguments of the YARN application master into mutable fields.
 *
 * Parsing happens during construction. An unrecognized argument, or a known option that is
 * missing its value, prints the usage text to stderr and exits the JVM with code 1.
 *
 * @param args the raw command-line arguments
 */
class ApplicationMasterArguments(val args: Array[String]) {
  var userJar: String = null
  var userClass: String = null
  var primaryPyFile: String = null
  var primaryRFile: String = null
  var userArgs: Seq[String] = Nil
  var propertiesFile: String = null

  parseArgs(args.toList)

  /** Consumes `inputArgs` as option/value pairs and stores the values in the fields above. */
  private def parseArgs(inputArgs: List[String]): Unit = {
    val collectedUserArgs = new ArrayBuffer[String]()

    @scala.annotation.tailrec
    def consume(remaining: List[String]): Unit = remaining match {
      case Nil =>
        ()

      case "--jar" :: value :: rest =>
        userJar = value
        consume(rest)

      case "--class" :: value :: rest =>
        userClass = value
        consume(rest)

      case "--primary-py-file" :: value :: rest =>
        primaryPyFile = value
        consume(rest)

      case "--primary-r-file" :: value :: rest =>
        primaryRFile = value
        consume(rest)

      case "--arg" :: value :: rest =>
        // May be given several times; the values keep their command-line order.
        collectedUserArgs += value
        consume(rest)

      case "--properties-file" :: value :: rest =>
        propertiesFile = value
        consume(rest)

      case unrecognized =>
        printUsageAndExit(1, unrecognized)
    }

    consume(inputArgs)

    // A Python primary file and an R primary file exclude each other.
    if (primaryPyFile != null && primaryRFile != null) {
      // scalastyle:off println
      System.err.println("Cannot have primary-py-file and primary-r-file at the same time")
      // scalastyle:on println
      System.exit(-1)
    }

    userArgs = collectedUserArgs.toList
  }

  /**
   * Prints an optional "unknown param" line plus the usage text to stderr, then exits the JVM.
   *
   * @param exitCode     the process exit code
   * @param unknownParam the offending argument(s); nothing extra is printed when null
   */
  def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = {
    // scalastyle:off println
    Option(unknownParam).foreach { param =>
      System.err.println("Unknown/unsupported param " + param)
    }
    System.err.println("""
      |Usage: org.apache.spark.deploy.yarn.ApplicationMaster [options]
      |Options:
      |  --jar JAR_PATH       Path to your application's JAR file
      |  --class CLASS_NAME   Name of your application's main class
      |  --primary-py-file    A main Python file
      |  --primary-r-file     A main R file
      |  --arg ARG            Argument to be passed to your application's main class.
      |                       Multiple invocations are possible, each will be passed in order.
      |  --properties-file FILE Path to a custom Spark properties file.
      """.stripMargin)
    // scalastyle:on println
    System.exit(exitCode)
  }
}

/** Constants associated with [[ApplicationMasterArguments]]. */
object ApplicationMasterArguments {
  /** Default executor count, as the name indicates; nothing visible here reads it. */
  val DEFAULT_NUMBER_EXECUTORS: Int = 2
}
Example 128
Source File: ClientArguments.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy.yarn

import scala.collection.mutable.ArrayBuffer

// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware !
/**
 * Parses the command-line arguments of the YARN client into mutable fields.
 *
 * Parsing happens during construction and throws an IllegalArgumentException for an
 * unrecognized argument, for a known option missing its value, and when both a Python and
 * an R primary file are given.
 *
 * @param args the raw command-line arguments
 */
private[spark] class ClientArguments(args: Array[String]) {

  var userJar: String = null
  var userClass: String = null
  var primaryPyFile: String = null
  var primaryRFile: String = null
  var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]()

  parseArgs(args.toList)

  /** Consumes `inputArgs` as option/value pairs and stores the values in the fields above. */
  private def parseArgs(inputArgs: List[String]): Unit = {
    @scala.annotation.tailrec
    def consume(remaining: List[String]): Unit = remaining match {
      case Nil =>
        ()

      case "--jar" :: value :: rest =>
        userJar = value
        consume(rest)

      case "--class" :: value :: rest =>
        userClass = value
        consume(rest)

      case "--primary-py-file" :: value :: rest =>
        primaryPyFile = value
        consume(rest)

      case "--primary-r-file" :: value :: rest =>
        primaryRFile = value
        consume(rest)

      case "--arg" :: value :: rest =>
        // May be given several times; the values keep their command-line order.
        userArgs += value
        consume(rest)

      case unrecognized =>
        throw new IllegalArgumentException(getUsageMessage(unrecognized))
    }

    consume(inputArgs)

    // A Python primary file and an R primary file exclude each other.
    val bothPrimaryFilesGiven = primaryPyFile != null && primaryRFile != null
    if (bothPrimaryFilesGiven) {
      throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" +
        " at the same time")
    }
  }

  /**
   * Builds the usage text, prefixed with an "unknown param" line unless `unknownParam` is null.
   */
  private def getUsageMessage(unknownParam: List[String] = null): String = {
    val unknownLine = Option(unknownParam) match {
      case Some(_) => s"Unknown/unsupported param $unknownParam\n"
      case None => ""
    }
    unknownLine +
      s"""
      |Usage: org.apache.spark.deploy.yarn.Client [options]
      |Options:
      |  --jar JAR_PATH           Path to your application's JAR file (required in yarn-cluster
      |                           mode)
      |  --class CLASS_NAME       Name of your application's main class (required)
      |  --primary-py-file        A main Python file
      |  --primary-r-file         A main R file
      |  --arg ARG                Argument to be passed to your application's main class.
      |                           Multiple invocations are possible, each will be passed in order.
      """.stripMargin
  }
}
Example 129
Source File: KPLBasedKinesisTestUtils.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.kinesis

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult}
import com.google.common.util.concurrent.{FutureCallback, Futures}

/**
 * Kinesis test utilities whose producer is a [[KPLDataGenerator]] when aggregation is
 * requested.
 *
 * @param streamShardCount shard count forwarded to `KinesisTestUtils`
 */
private[kinesis] class KPLBasedKinesisTestUtils(streamShardCount: Int = 2)
    extends KinesisTestUtils(streamShardCount) {

  /** Picks the KPL-backed generator for aggregated sends and the simple one otherwise. */
  override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
    if (aggregate) {
      new KPLDataGenerator(regionName)
    } else {
      new SimpleDataGenerator(kinesisClient)
    }
  }
}


/**
 * Data generator that sends records through the Kinesis Producer Library.
 *
 * @param regionName the region the producer is configured for
 */
private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator {

  // Created on first use.
  private lazy val producer: KPLProducer = {
    new KPLProducer(
      new KinesisProducerConfiguration()
        .setRecordMaxBufferedTime(1000)
        .setMaxConnections(1)
        .setRegion(regionName)
        .setMetricsLevel("none"))
  }

  /**
   * Sends every number as its decimal string, using that string as both the second argument
   * of `addUserRecord` and the UTF-8 payload, and records which shard and sequence number
   * each successful send reported.
   *
   * @param streamName the target stream
   * @param data       the numbers to send
   * @return per shard id, the (number, sequence number) pairs reported by successful sends;
   *         failed sends are ignored
   */
  override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = {
    val sentByShard = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
    for (num <- data) {
      val text = num.toString
      val payload = ByteBuffer.wrap(text.getBytes(StandardCharsets.UTF_8))
      val pendingResult = producer.addUserRecord(streamName, text, payload)
      val recordSequenceNumber = new FutureCallback[UserRecordResult]() {
        override def onFailure(t: Throwable): Unit = {} // do nothing

        override def onSuccess(result: UserRecordResult): Unit = {
          val shardId = result.getShardId
          val seqNumber = result.getSequenceNumber()
          val forShard = sentByShard.getOrElseUpdate(shardId, new ArrayBuffer[(Int, String)]())
          forShard += ((num, seqNumber))
        }
      }
      Futures.addCallback(pendingResult, recordSequenceNumber)
    }
    // Flush before the collected results are snapshotted.
    producer.flushSync()
    sentByShard.toMap
  }
}
Example 130
Source File: Exchange.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType


/**
 * Rule that replaces an exchange by a `ReusedExchangeExec` pointing at an earlier exchange
 * in the plan that produces the same result.
 *
 * @param conf the SQL configuration; the rule is a no-op unless exchange reuse is enabled
 */
case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (conf.exchangeReuseEnabled) {
      // Bucket the exchanges seen so far by schema so that sameResult is only evaluated
      // against candidates with an identical schema, instead of against every exchange.
      val seenBySchema = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
      plan.transformUp {
        case exchange: Exchange =>
          val candidates = seenBySchema.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
          candidates.find(earlier => exchange.sameResult(earlier)) match {
            case Some(earlier) =>
              // Keep this exchange's own output: the plans above it resolve their attributes
              // against it.
              ReusedExchangeExec(exchange.output, earlier)
            case None =>
              candidates += exchange
              exchange
          }
      }
    } else {
      plan
    }
  }
}
Example 131
Source File: SQLAppStatusStore.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.ui

import java.lang.{Long => JLong}
import java.util.Date

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.databind.annotation.JsonDeserialize

import org.apache.spark.JobExecutionStatus
import org.apache.spark.status.KVUtils.KVIndexParam
import org.apache.spark.util.kvstore.{KVIndex, KVStore}


/**
 * Holds either a plain graph node or a cluster wrapper; exactly one of the two must be
 * non-null.
 *
 * @param node    the plain node, or null when `cluster` is set
 * @param cluster the cluster wrapper, or null when `node` is set
 */
class SparkPlanGraphNodeWrapper(
    val node: SparkPlanGraphNode,
    val cluster: SparkPlanGraphClusterWrapper) {

  /**
   * Returns `node` when it is set, otherwise the cluster converted through
   * `toSparkPlanGraphCluster()`.
   *
   * Fails the assertion when both or neither of `node` and `cluster` are null.
   */
  def toSparkPlanGraphNode(): SparkPlanGraphNode = {
    // XOR of the two null checks: true only when exactly one of them is null.
    assert(node == null ^ cluster == null, "One and only one of node or cluster must be set.")
    if (node != null) node else cluster.toSparkPlanGraphCluster()
  }

}

/**
 * Describes one SQL plan metric.
 *
 * @param name          the metric's name
 * @param accumulatorId id of the accumulator associated with the metric
 * @param metricType    the metric's type, as a string
 */
case class SQLPlanMetric(name: String, accumulatorId: Long, metricType: String)
Example 132
Source File: ManifestFileCommitProtocol.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.util.UUID

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}

import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage


  /**
   * Stores the sink log and batch id used by the job-level methods, which require the log to
   * be set before they run.
   *
   * @param fileLog the sink log that committed files are added to
   * @param batchId the id of the batch being written
   */
  def setupManifestOptions(fileLog: FileStreamSinkLog, batchId: Long): Unit = {
    this.batchId = batchId
    this.fileLog = fileLog
  }

  /** Only verifies that the sink log has been configured; no other job setup is done. */
  override def setupJob(jobContext: JobContext): Unit = {
    val configured = fileLog != null
    require(configured, "setupManifestOptions must be called before this function")
  }

  /**
   * Gathers the file statuses carried by all task commit messages and adds them to the sink
   * log under the current batch id.
   *
   * @param jobContext  the job context (unused)
   * @param taskCommits the commit messages returned by the tasks
   * @throws IllegalStateException when the log reports that the batch could not be added
   */
  override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
    require(fileLog != null, "setupManifestOptions must be called before this function")
    val committedFiles = taskCommits.flatMap { commit =>
      commit.obj.asInstanceOf[Seq[SinkFileStatus]]
    }.toArray

    val added = fileLog.add(batchId, committedFiles)
    if (!added) {
      throw new IllegalStateException(s"Race while writing batch $batchId")
    }
    logInfo(s"Committed batch $batchId")
  }

  /** Only verifies that the sink log has been configured; aborting performs no cleanup. */
  override def abortJob(jobContext: JobContext): Unit = {
    val configured = fileLog != null
    require(configured, "setupManifestOptions must be called before this function")
  }

  /** Starts the task with a fresh, empty list of added files. */
  override def setupTask(taskContext: TaskAttemptContext): Unit = {
    addedFiles = ArrayBuffer.empty[String]
  }

  /**
   * Generates a new file path for the task and remembers it in `addedFiles`.
   *
   * The file name is `part-<task id, zero-padded to at least 5 digits>-<random UUID><ext>`.
   * The padding never truncates, so task ids of 100000 and above simply use more digits.
   *
   * @param taskContext the task attempt whose task id goes into the name
   * @param dir         optional sub-directory below `path`
   * @param ext         suffix appended verbatim after the UUID
   * @return the generated path as a string
   */
  override def newTaskTempFile(
      taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
    val split = taskContext.getTaskAttemptID.getTaskID.getId
    val uuid = UUID.randomUUID.toString
    val filename = f"part-$split%05d-$uuid$ext"

    val file = dir match {
      case Some(subDir) => new Path(new Path(path, subDir), filename).toString
      case None => new Path(path, filename).toString
    }

    addedFiles += file
    file
  }

  /**
   * Not supported by this protocol.
   *
   * @throws UnsupportedOperationException always
   */
  override def newTaskTempFileAbsPath(
      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
    val reason = s"$this does not support adding files with an absolute path"
    throw new UnsupportedOperationException(reason)
  }

  /**
   * Builds the task's commit message from the status of every file added by the task.
   *
   * The file system is resolved from the first added file and used for all of them; a task
   * that added no files returns a message carrying an empty sequence.
   *
   * @param taskContext supplies the configuration used to resolve the file system
   * @return a commit message carrying the `SinkFileStatus` of each added file
   */
  override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
    addedFiles.headOption match {
      case Some(firstFile) =>
        val fs = new Path(firstFile).getFileSystem(taskContext.getConfiguration)
        val statuses: Seq[SinkFileStatus] = addedFiles.map { file =>
          SinkFileStatus(fs.getFileStatus(new Path(file)))
        }
        new TaskCommitMessage(statuses)
      case None =>
        new TaskCommitMessage(Seq.empty[SinkFileStatus])
    }
  }

  /** Intentionally a no-op: files already added by the task are left in place. */
  override def abortTask(taskContext: TaskAttemptContext): Unit = {
    // TODO: we can also try delete the addedFiles as a best-effort cleanup.
    ()
  }
} 
Example 133
Source File: BatchEvalPythonExecSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.BooleanType

/**
 * Checks where filter predicates end up relative to `BatchEvalPythonExec` in the executed
 * physical plan when a query's WHERE clause calls a Python UDF.
 *
 * Every test builds a small DataFrame, applies a filter that references the UDF registered as
 * "dummyPythonUDF", and counts the plan nodes that match the expected shapes.
 */
class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.newProductEncoder
  import testImplicits.localSeqToDatasetHolder

  // Registers the dummy UDF under the name used by the query strings below.
  override def beforeAll(): Unit = {
    super.beforeAll()
    spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)
  }

  // Drops the UDF again before the shared session is torn down.
  override def afterAll(): Unit = {
    spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF"))
    super.afterAll()
  }

  // Expects exactly two matching nodes: a FilterExec on top of the Python evaluation whose
  // condition is an And of two attribute references, and a BatchEvalPythonExec whose child is a
  // code-generated FilterExec carrying the `In` predicate.
  test("Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: AttributeReference),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  // Same expectation with a UDF call nested inside another: the outer FilterExec condition is a
  // single attribute reference here.
  test("Nested Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  // With `rand()` in the condition, the FilterExec above the Python evaluation is expected to
  // hold an And of an attribute reference and a GreaterThan.
  test("Python UDF: no push down on non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("b > 4 and dummyPythonUDF(a) and rand() > 0.3")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  // Same plan shapes as the previous test, with the deterministic `b > 4` placed after the
  // non-deterministic `rand()` predicate in the query text.
  test("Python UDF: push down on deterministic predicates after the first non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4")

    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  // The UDF arguments come from both sides of a cross join; the plan is expected to contain
  // exactly one BatchEvalPythonExec node.
  test("Python UDF refers to the attributes from more than one child") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = Seq(("Hello", 4)).toDF("c", "d")
    val joinDF = df.crossJoin(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
    val qualifiedPlanNodes = joinDF.queryExecution.executedPlan.collect {
      case b: BatchEvalPythonExec => b
    }
    assert(qualifiedPlanNodes.size == 1)
  }
}

// A placeholder Python function used only for testing; it cannot actually be executed.
// Every constructor argument is an empty or null stand-in (empty command bytes, empty strings,
// null broadcast variables and accumulator).
class DummyUDF extends PythonFunction(
  command = Array[Byte](),
  envVars = Map("" -> "").asJava,
  pythonIncludes = ArrayBuffer("").asJava,
  pythonExec = "",
  pythonVer = "",
  broadcastVars = null,
  accumulator = null)

// User-defined Python function wrapping DummyUDF: named "dummyUDF", declared to return
// BooleanType, evaluated as a batched SQL UDF, and flagged as deterministic.
class MyDummyPythonUDF extends UserDefinedPythonFunction(
  name = "dummyUDF",
  func = new DummyUDF,
  dataType = BooleanType,
  pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
  udfDeterministic = true)
Example 134
Source File: UnionDStream.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.dstream

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Duration, Time}

/**
 * A DStream whose RDD for each batch is the union of the RDDs produced by all `parents`
 * for that batch.
 *
 * @param parents the streams to union; they must share one context and one slide duration
 */
private[streaming]
class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
  extends DStream[T](parents.head.ssc) {

  // NOTE: `parents.head` in the superclass constructor call above is evaluated before these
  // checks run, so an empty array fails there rather than with the first message below.
  require(parents.length > 0, "List of DStreams to union is empty")
  require(parents.map(_.ssc).distinct.length == 1, "Some of the DStreams have different contexts")
  require(parents.map(_.slideDuration).distinct.length == 1,
    "Some of the DStreams have different slide durations")

  override def dependencies: List[DStream[_]] = parents.toList

  override def slideDuration: Duration = parents.head.slideDuration

  /**
   * Asks every parent for its RDD at `validTime` and unions the results.
   *
   * @throws SparkException if any parent produced no RDD for this time
   */
  override def compute(validTime: Time): Option[RDD[T]] = {
    // All parents are queried first; missing results are only detected afterwards.
    val parentResults = parents.map(_.getOrCompute(validTime))
    val rdds = new ArrayBuffer[RDD[T]]()
    for (result <- parentResults) {
      result match {
        case Some(rdd) => rdds += rdd
        case None => throw new SparkException("Could not generate RDD from a parent for unifying at" +
          s" time $validTime")
      }
    }
    if (rdds.isEmpty) None else Some(ssc.sc.union(rdds))
  }
}
Example 135
Source File: QueueInputDStream.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.dstream

import java.io.{NotSerializableException, ObjectInputStream, ObjectOutputStream}

import scala.collection.mutable.{ArrayBuffer, Queue}
import scala.reflect.ClassTag

import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.streaming.{StreamingContext, Time}

/**
 * An input stream that takes the RDDs for each batch from an in-memory queue.
 *
 * @param ssc        the streaming context this stream belongs to
 * @param queue      source of RDDs; it is synchronized on while being drained
 * @param oneAtATime when true, at most one queued RDD is consumed per batch; otherwise the
 *                   whole queue is drained and unioned
 * @param defaultRDD RDD returned when nothing was dequeued; may be null
 */
private[streaming]
class QueueInputDStream[T: ClassTag](
    ssc: StreamingContext,
    val queue: Queue[RDD[T]],
    oneAtATime: Boolean,
    defaultRDD: RDD[T]
  ) extends InputDStream[T](ssc) {

  override def start(): Unit = {}

  override def stop(): Unit = {}

  // Java serialization hook: deserializing this stream is always rejected.
  private def readObject(in: ObjectInputStream): Unit = {
    throw new NotSerializableException("queueStream doesn't support checkpointing. " +
      "Please don't use queueStream when checkpointing is enabled.")
  }

  // Java serialization hook: writes nothing and only logs a warning.
  private def writeObject(oos: ObjectOutputStream): Unit = {
    logWarning("queueStream doesn't support checkpointing")
  }

  /**
   * Produces the RDD for `validTime`: the dequeued RDD(s) if any, otherwise `defaultRDD`
   * when it is non-null, otherwise an empty RDD. The result is always a `Some`.
   */
  override def compute(validTime: Time): Option[RDD[T]] = {
    val drained = new ArrayBuffer[RDD[T]]()
    queue.synchronized {
      if (!oneAtATime || queue.isEmpty) {
        // Take everything (a no-op when the queue is already empty).
        drained ++= queue
        queue.clear()
      } else {
        drained += queue.dequeue()
      }
    }
    val batch: RDD[T] =
      if (drained.isEmpty) {
        if (defaultRDD != null) defaultRDD else ssc.sparkContext.emptyRDD[T]
      } else if (oneAtATime) {
        drained.head
      } else {
        new UnionRDD(context.sc, drained.toSeq)
      }
    Some(batch)
  }

}
Example 136
Source File: LocalSparkCluster.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkConf
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.util.Utils


    for (workerNum <- 1 to numWorkers) {
      val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker,
        memoryPerWorker, masters, null, Some(workerNum), _conf)
      workerRpcEnvs += workerEnv
    }

    masters
  }

  /**
   * Shuts down every worker and master RPC environment, waits for each of them to terminate,
   * and then forgets them.
   */
  def stop(): Unit = {
    logInfo("Shutting down local Spark cluster.")
    // Stop the workers before the master so they don't get upset that it disconnected
    for (env <- workerRpcEnvs) env.shutdown()
    for (env <- masterRpcEnvs) env.shutdown()
    for (env <- workerRpcEnvs) env.awaitTermination()
    for (env <- masterRpcEnvs) env.awaitTermination()
    masterRpcEnvs.clear()
    workerRpcEnvs.clear()
  }
} 
Example 137
Source File: TaskResult.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.io._
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkEnv
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.storage.BlockId
import org.apache.spark.util.{AccumulatorV2, Utils}

// The result of a task. Also contains updates to accumulator variables.
// Sealed: all implementations must be declared in this source file.
private[spark] sealed trait TaskResult[T]


  /**
   * Returns the task's value, deserializing it from `valueBytes` on first access and caching
   * it in `valueObject` afterwards.
   *
   * @param resultSer serializer to use; when null, a new instance is obtained from `SparkEnv`
   */
  def value(resultSer: SerializerInstance = null): T = {
    if (!valueObjectDeserialized) {
      // This should not run when holding a lock because it may cost dozens of seconds for a large
      // value
      val serializer = Option(resultSer).getOrElse(SparkEnv.get.serializer.newInstance())
      valueObject = serializer.deserialize(valueBytes)
      valueObjectDeserialized = true
    }
    valueObject
  }
} 
Example 138
Source File: Schedulable.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.scheduler.SchedulingMode.SchedulingMode


/**
 * Contract for an entity that takes part in scheduling: it has a parent `Pool`, a queue of
 * child schedulables, a set of scheduling attributes, and operations for managing children.
 */
private[spark] trait Schedulable {
  // The pool this schedulable belongs to.
  var parent: Pool
  // child queues
  def schedulableQueue: ConcurrentLinkedQueue[Schedulable]
  def schedulingMode: SchedulingMode
  // Scheduling attributes exposed by every schedulable.
  def weight: Int
  def minShare: Int
  def runningTasks: Int
  def priority: Int
  def stageId: Int
  def name: String

  // Child management and lookup by name.
  def addSchedulable(schedulable: Schedulable): Unit
  def removeSchedulable(schedulable: Schedulable): Unit
  def getSchedulableByName(name: String): Schedulable
  // Notification that an executor on `host` was lost, with the reason.
  def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit
  def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean
  // The task set managers under this schedulable, in sorted order.
  def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager]
}
Example 139
Source File: ChunkedByteBufferOutputStream.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.io

import java.io.OutputStream
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.storage.StorageUtils


  // Write offset inside the current chunk. It starts at chunkSize so that the very first write
  // sees a "full" chunk and allocates one (see allocateNewChunkIfNeeded).
  private[this] var position = chunkSize
  // Total number of bytes written so far. Kept as a Long (matching the type returned by `size`)
  // so that writing more than Int.MaxValue bytes does not overflow the counter.
  private[this] var _size = 0L
  // Set by close(); writes are rejected afterwards.
  private[this] var closed: Boolean = false

  /** Total number of bytes written to this stream so far. */
  def size: Long = _size

  /** Closes the stream; calls after the first one have no effect. */
  override def close(): Unit = {
    val alreadyClosed = closed
    if (!alreadyClosed) {
      super.close()
      closed = true
    }
  }

  /**
   * Appends a single byte (the low 8 bits of `b`), allocating a new chunk first when the
   * current one is full.
   */
  override def write(b: Int): Unit = {
    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
    allocateNewChunkIfNeeded()
    val currentChunk = chunks(lastChunkIndex)
    currentChunk.put(b.toByte)
    position += 1
    _size += 1
  }

  /**
   * Appends `len` bytes of `bytes` starting at `off`, spilling across as many chunks as
   * needed.
   */
  override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
    var remaining = len
    var sourceOffset = off
    while (remaining > 0) {
      allocateNewChunkIfNeeded()
      // Copy as much as still fits into the current chunk.
      val batch = math.min(chunkSize - position, remaining)
      chunks(lastChunkIndex).put(bytes, sourceOffset, batch)
      sourceOffset += batch
      remaining -= batch
      position += batch
    }
    _size += len
  }

  // Starts a fresh chunk when the current one has been filled completely (which is also the
  // initial state, because `position` starts out equal to `chunkSize`).
  @inline
  private def allocateNewChunkIfNeeded(): Unit = {
    val currentChunkIsFull = position == chunkSize
    if (currentChunkIsFull) {
      chunks += allocator(chunkSize)
      lastChunkIndex += 1
      position = 0
    }
  }

  /**
   * Hands the written data over as a `ChunkedByteBuffer`.
   *
   * May be called only after `close()` and only once; both conditions are enforced with
   * `require`. All returned buffers are flipped, i.e. ready for reading.
   */
  def toChunkedByteBuffer: ChunkedByteBuffer = {
    require(closed, "cannot call toChunkedByteBuffer() unless close() has been called")
    require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
    toChunkedByteBufferWasCalled = true
    // NOTE(review): -1 is presumably the initial value of lastChunkIndex, i.e. nothing was
    // ever written — confirm against the field declaration.
    if (lastChunkIndex == -1) {
      new ChunkedByteBuffer(Array.empty[ByteBuffer])
    } else {
      // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk.
      // An alternative would have been returning an array of ByteBuffers, with the last buffer
      // bounded to only the last chunk's position. However, given our use case in Spark (to put
      // the chunks in block manager), only limiting the view bound of the buffer would still
      // require the block manager to store the whole chunk.
      val ret = new Array[ByteBuffer](chunks.size)
      for (i <- 0 until chunks.size - 1) {
        ret(i) = chunks(i)
        ret(i).flip()
      }
      if (position == chunkSize) {
        // The last chunk is completely full, so it can be reused as is.
        ret(lastChunkIndex) = chunks(lastChunkIndex)
        ret(lastChunkIndex).flip()
      } else {
        // The last chunk is only partly filled: copy its `position` bytes into a buffer of
        // exactly that size, then dispose of the original chunk.
        ret(lastChunkIndex) = allocator(position)
        chunks(lastChunkIndex).flip()
        ret(lastChunkIndex).put(chunks(lastChunkIndex))
        ret(lastChunkIndex).flip()
        StorageUtils.dispose(chunks(lastChunkIndex))
      }
      new ChunkedByteBuffer(ret)
    }
  }
} 
Example 140
Source File: UnionRDD.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.forkjoin.ForkJoinPool
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils


/**
 * A partition of a union: it points at one partition of one of the parent RDDs.
 *
 * @param idx index of this partition within the union
 * @param rdd the parent RDD this partition comes from (transient, so not serialized)
 * @param parentRddIndex position of the parent RDD in the union's list of RDDs
 * @param parentRddPartitionIndex index of the partition within the parent RDD (transient)
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient private val rdd: RDD[T],
    val parentRddIndex: Int,
    @transient private val parentRddPartitionIndex: Int)
  extends Partition {

  // Looked up eagerly at construction and refreshed again in writeObject.
  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  // Delegates to the parent RDD's preferred locations for the wrapped partition.
  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  // Java serialization hook; any failure is routed through Utils.tryOrIOException.
  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)
    oos.defaultWriteObject()
  }
}

object UnionRDD {
  // Task support backed by a fork-join pool with parallelism 8, created lazily on first use.
  // UnionRDD.getPartitions uses it when listing parent partitions in parallel.
  private[spark] lazy val partitionEvalTaskSupport =
    new ForkJoinTaskSupport(new ForkJoinPool(8))
}

/**
 * An RDD formed by concatenating the partitions of several parent RDDs, in order.
 *
 * @param sc   the Spark context
 * @param rdds the parent RDDs; set to null by `clearDependencies()`
 */
@DeveloperApi
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  // visible for testing
  private[spark] val isPartitionListingParallel: Boolean =
    rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)

  /**
   * Creates one `UnionPartition` per parent partition, numbered consecutively across parents.
   * The total count is computed through a parallel collection when there are more parents than
   * the configured threshold.
   */
  override def getPartitions: Array[Partition] = {
    val listing = if (isPartitionListingParallel) {
      val parallel = rdds.par
      parallel.tasksupport = UnionRDD.partitionEvalTaskSupport
      parallel
    } else {
      rdds
    }
    val totalPartitions = listing.map(_.partitions.length).seq.sum
    val result = new Array[Partition](totalPartitions)
    var next = 0
    rdds.zipWithIndex.foreach { case (rdd, rddIndex) =>
      rdd.partitions.foreach { split =>
        result(next) = new UnionPartition(next, rdd, rddIndex, split.index)
        next += 1
      }
    }
    result
  }

  /**
   * One range dependency per parent, mapping the parent's partitions onto a contiguous range
   * of this RDD's partitions.
   */
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    rdds.foldLeft(0) { (outputStart, rdd) =>
      deps += new RangeDependency(rdd, 0, outputStart, rdd.partitions.length)
      outputStart + rdd.partitions.length
    }
    deps
  }

  /** Reads the wrapped parent partition through the corresponding parent RDD. */
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val unionPart = s.asInstanceOf[UnionPartition[T]]
    val parentRdd = parent[T](unionPart.parentRddIndex)
    parentRdd.iterator(unionPart.parentPartition, context)
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    val unionPart = s.asInstanceOf[UnionPartition[T]]
    unionPart.preferredLocations()
  }

  /** Drops the references to the parent RDDs in addition to the inherited cleanup. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdds = null
  }
}
Example 141
Source File: ValueJsonConversionTest.scala    From ingraph   with Eclipse Public License 1.0 5 votes vote down vote up
package ingraph.compiler.sql.driver

import ingraph.compiler.sql.driver.ValueJsonConversion._
import ingraph.compiler.sql.driver.ValueJsonConversionTest._
import org.neo4j.driver.internal.value._
import org.neo4j.driver.internal.{InternalNode, InternalPath, InternalRelationship}
import org.neo4j.driver.v1.Value
import org.scalactic.source
import org.scalactic.source.Position
import org.scalatest.FunSuite

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

/**
 * Round-trip test: every value registered in the companion object is serialized to JSON with
 * `gson`, deserialized again, and compared with the original.
 */
class ValueJsonConversionTest extends FunSuite {
  for ((value, testName, pos) <- testParameters) {
    // `pos` is the source position captured where the value was registered; it is passed
    // explicitly as the position of the generated test.
    test(testName) {
      println(value)

      val json = gson.toJson(value, classOf[Value])
      println(json)

      val roundTripped = gson.fromJson(json, classOf[Value])

      assert(value == roundTripped)
    }(pos)
  }
}

/**
 * Registry of the values exercised by the round-trip test.
 *
 * Initialization order matters: the object body runs top to bottom, so the two buffers and
 * the shared fixture values must be defined before the `addTest` calls that use them.
 */
object ValueJsonConversionTest {
  // All registered values, in registration order.
  val testValues: ArrayBuffer[Value] = ArrayBuffer.empty
  // (value, test name, source position of the registering call) for each registered value.
  val testParameters: ArrayBuffer[(Value, String, Position)] = ArrayBuffer.empty

  /**
   * Registers a value to be tested.
   *
   * @param value    the value to round-trip
   * @param testName name of the generated test; when null, the value's simple class name is used
   * @param pos      source position of the caller, supplied implicitly
   */
  def addTest(value: Value, testName: String = null)(implicit pos: source.Position): Unit = {
    testValues += value
    testParameters += ((value, Option(testName).getOrElse(value.getClass.getSimpleName), pos))
  }

  // Shared fixtures reused by several of the registrations below.
  private val stringValue = new StringValue("John")
  private val integerValue = new IntegerValue(101)
  private val propertiesMap = Map[String, Value]("name" -> stringValue).asJava

  // Container and graph-structure values.
  addTest(new MapValue(propertiesMap))
  addTest(new BytesValue(Array[Byte](0, 42, 127, -128)))
  addTest(new ListValue(stringValue, integerValue))
  addTest(new NodeValue(new InternalNode(5, List("Label1", "Label2").asJavaCollection, propertiesMap)))
  addTest(new RelationshipValue(new InternalRelationship(42, 10, 20, "Edge_Type_1", propertiesMap)))
  addTest(new PathValue(new InternalPath(
    new InternalNode(0),
    new InternalRelationship(101, 0, 1, "TYPE_A"),
    new InternalNode(1)
  )))
  // Scalar values.
  addTest(BooleanValue.FALSE)
  addTest(BooleanValue.TRUE)
  addTest(NullValue.NULL)
  addTest(stringValue)
  addTest(integerValue)
  addTest(new FloatValue(3.14))
}
Example 142
Source File: TokenStreamUtils.scala    From odinson   with Apache License 2.0 5 votes vote down vote up
package ai.lum.odinson.lucene.analysis

import scala.collection.mutable.ArrayBuffer
import org.apache.lucene.analysis.Analyzer
import org.apache.lucene.analysis.TokenStream
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute
import org.apache.lucene.search.IndexSearcher
import org.apache.lucene.search.highlight.TokenSources

/** Helpers for turning Lucene token streams into plain arrays of term strings. */
object TokenStreamUtils {

  /**
   * Returns the terms of one field of an indexed document.
   *
   * The token stream is obtained through `TokenSources.getTokenStream`, which is given both the
   * document's term vectors and the stored field text.
   *
   * @param docID         Lucene document id
   * @param fieldName     name of the field to tokenize
   * @param indexSearcher searcher used to load the document and its term vectors
   * @param analyzer      analyzer handed to `TokenSources`
   * @return the terms of the field, in stream order
   */
  def getTokens(
    docID: Int,
    fieldName: String,
    indexSearcher: IndexSearcher,
    analyzer: Analyzer
  ): Array[String] = {
    val doc = indexSearcher.doc(docID)
    val tvs = indexSearcher.getIndexReader().getTermVectors(docID)
    val text = doc.getField(fieldName).stringValue
    val ts = TokenSources.getTokenStream(fieldName, tvs, text, analyzer, -1)
    getTokens(ts)
  }

  /**
   * Consumes the given token stream and returns its terms.
   *
   * The stream is always closed, even when consuming it fails part-way.
   *
   * @param ts the stream to consume; it must not be reused afterwards
   * @return the string form of every term produced by the stream, in order
   */
  def getTokens(ts: TokenStream): Array[String] = {
    try {
      // The attribute instance is shared across tokens, so it only needs to be obtained once.
      val termAttribute = ts.addAttribute(classOf[CharTermAttribute])
      ts.reset()
      val terms = new ArrayBuffer[String]
      while (ts.incrementToken()) {
        terms += termAttribute.toString
      }
      ts.end()
      terms.toArray
    } finally {
      ts.close()
    }
  }

}
Example 143
Source File: Driver.scala    From OnlineLDA_Spark   with Apache License 2.0 5 votes vote down vote up
package com.github.yuhao.yang

import java.util.Calendar
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkContext, SparkConf}
import scala.collection.mutable.ArrayBuffer

/** Command-line entry point that trains an online LDA model and prints its top terms. */
object Driver extends Serializable {

  /**
   * Expected arguments: `args(0)` input directory prefix (used by plain concatenation, so it
   * should end with a path separator), `args(1)` number of topics, `args(2)` integer passed as
   * the last argument of `runOnlineMode`.
   */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.ERROR)
    Logger.getLogger("akka").setLevel(Level.ERROR)
    val inputDir = args(0)
    val filePaths = extractPaths(s"${inputDir}texts", recursive = true)
    val stopWordsPath = s"${inputDir}stop.txt"
    val vocabPath = s"${inputDir}wordsEn.txt"

    println(s"begin: ${Calendar.getInstance().getTime}")
    println(s"path size: ${filePaths.size}")
    assert(filePaths.size > 0)

    val conf = new SparkConf().setAppName("online LDA Spark")
    val sc = new SparkContext(conf)

    val vocab = Docs2Vec.extractVocab(sc, Seq(vocabPath), stopWordsPath)
    val vocabArray = vocab.map(_.swap)

    val K = args(1).toInt
//    val lda = OnlineLDA_Spark.runBatchMode(sc, filePaths, vocab, K, 50)
    val lda = OnlineLDA_Spark.runOnlineMode(sc, filePaths, vocab, K, args(2).toInt)

    println("_lambda:")
    (0 until lda._lambda.rows).foreach { row =>
      val weights = lda._lambda(row, ::).t
      val topIndices = lda._lambda(row, ::).t.argtopk(10)
      // Pair each of the ten strongest entries with its vocabulary term, strongest first.
      val ranked = topIndices.map(k => (vocabArray(k), weights(k))).sortBy(_._2).reverse
      val words = ranked.map(_._1).mkString(",")
      val scores = ranked.map(entry => "%2.2f".format(entry._2)).mkString(",")
      // Printed as a single tuple: "(words,scores)".
      println((words, scores))
    }

    println(s"end: ${Calendar.getInstance().getTime()}")

  }

  /**
   * Lists the absolute paths of the files under `path`.
   *
   * @param path      directory to list
   * @param recursive whether to descend into sub-directories; when false they are skipped
   * @return the file paths in listing order, or an empty array when `path` cannot be listed
   */
  def extractPaths(path: String, recursive: Boolean = true): Array[String] = {
    val collected = ArrayBuffer[String]()
    // listFiles() yields null when the path cannot be listed; Option turns that into "no entries".
    Option(new java.io.File(path).listFiles()).foreach { entries =>
      entries.foreach { entry =>
        if (!entry.isDirectory) {
          collected += entry.getAbsolutePath
        } else if (recursive) {
          collected ++= extractPaths(entry.getAbsolutePath, true)
        }
      }
    }
    collected.toArray
  }

}
Example 144
Source File: QuerySuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package edu.ucla.cs.wis.bigdatalog.spark

import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
import org.scalatest.FunSuite

import scala.collection.mutable.ArrayBuffer

/**
 * Base class for suites that run BigDatalog programs on a local Spark context and compare the
 * query results with expected answers.
 */
abstract class QuerySuite extends FunSuite with Logging {

  /**
   * One query scenario.
   *
   * @param program     the program text to load
   * @param query       the query to run against the program
   * @param data        rows to load, keyed by relation name
   * @param answers     expected results as strings, or null when only the size is known
   * @param answersSize expected number of results
   */
  case class TestCase(program: String, query: String, data: Map[String, Seq[String]], answers: Seq[String], answersSize: Int) {
    def this(program: String, query: String, data: Map[String, Seq[String]], answersSize: Int) = this(program, query, data, null, answersSize)

    def this(program: String, query: String, data: Map[String, Seq[String]], answers: Seq[String]) = this(program, query, data, answers, answers.size)
  }

  /** Runs a single test case in its own Spark context. */
  def runTest(testCase: TestCase): Unit = runTests(Seq(testCase))

  /**
   * Runs the given test cases one after another on a shared local Spark context.
   *
   * The context is stopped when this method exits, including when an assertion or a query
   * fails, so that a failing test does not leave a running context behind for later tests.
   *
   * @throws SparkException if a test case supplies data for a relation the program does not know
   */
  def runTests(testCases: Seq[TestCase]): Unit = {
    val sparkCtx = new SparkContext("local[*]", "QuerySuite", new SparkConf()
      .set("spark.eventLog.enabled", "true")
      //.set("spark.eventLog.dir", "../logs")
      .set("spark.ui.enabled", "false")
      .set("spark.sql.shuffle.partitions", "5")
      .setAll(Map.empty[String, String])
    )

    try {
      val bigDatalogCtx = new BigDatalogContext(sparkCtx)

      var count: Int = 1
      for (testCase <- testCases) {
        bigDatalogCtx.loadProgram(testCase.program)

        for ((relationName, data) <- testCase.data) {
          val relationInfo = bigDatalogCtx.relationCatalog.getRelationInfo(relationName)
          if (relationInfo == null)
            throw new SparkException("You are attempting to load an unknown relation.")

          bigDatalogCtx.registerAndLoadTable(relationName, data, bigDatalogCtx.conf.numShufflePartitions)
        }

        val query = testCase.query
        val answers = testCase.answers
        logInfo("========== START BigDatalog Query " + count + " START ==========")
        val program = bigDatalogCtx.query(query)

        val results = program.execute().collect()

        // for some test cases we will only know the size of the answer set, not the actual answers
        if (answers == null) {
          assert(results.size == testCase.answersSize)
        } else {
          if (results.size != answers.size) {
            // Log what differs before failing on the size mismatch.
            displayDifferences(results.map(_.toString), answers)
            // yes this will fail
            assert(results.size == answers.size)
          } else {
            for (result <- results)
              assert(answers.contains(result.toString()))
          }

          val resultStrings = results.map(_.toString).toSet

          for (answer <- answers)
            assert(resultStrings.contains(answer.toString()))
        }
        logInfo("========== END BigDatalog Query " + count + " END ==========\n")
        count += 1
        bigDatalogCtx.reset()
      }
    } finally {
      // Always release the context, even when a test case failed.
      sparkCtx.stop()
    }
  }

  /** Logs the results missing from the answers and the answers missing from the results. */
  private def displayDifferences(results: Seq[String], answers: Seq[String]): Unit = {
    val missingAnswers = new ArrayBuffer[String]
    val missingResults = new ArrayBuffer[String]

    for (result <- results)
      if (!answers.contains(result))
        missingAnswers += result

    for (answer <- answers)
      if (!results.contains(answer))
        missingResults += answer

    if (missingAnswers.nonEmpty)
      logInfo("Results not in Answers: " + missingAnswers.mkString(", "))

    if (missingResults.nonEmpty)
      logInfo("Answers not in Results: " + missingResults.mkString(", "))
  }
}
Example 145
Source File: SpearmanCorrelation.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.stat.correlation

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
import org.apache.spark.rdd.RDD


  /**
   * Computes a correlation matrix from the ranks of the input values: every value is replaced
   * by its rank within its column (tied values share their average rank), and the resulting
   * rank vectors are handed to `PearsonCorrelation.computeCorrelationMatrix`.
   *
   * @param X input rows, one vector per row
   * @return the matrix returned by the Pearson computation on the rank vectors
   */
  override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
    // ((columnIndex, value), rowUid)
    val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) =>
      vec.toArray.view.zipWithIndex.map { case (v, j) =>
        ((j, v), uid)
      }
    }
    // global sort by (columnIndex, value)
    val sorted = colBased.sortByKey()
    // assign global ranks (using average ranks for tied values)
    val globalRanks = sorted.zipWithIndex().mapPartitions { iter =>
      // State of the current run of equal (column, value) entries. It is local to this
      // closure, so runs are tracked separately within each partition.
      // preVal starts as NaN, which never compares equal, so the first entry always starts a run.
      var preCol = -1
      var preVal = Double.NaN
      var startRank = -1.0
      var cachedUids = ArrayBuffer.empty[Long]
      // Emits (uid, (column, averageRank)) for the buffered run and empties the buffer.
      // The run occupies consecutive ranks starting at startRank, hence the average below.
      // The mapped output is a separate collection, so clearing cachedUids does not affect it.
      val flush: () => Iterable[(Long, (Int, Double))] = () => {
        val averageRank = startRank + (cachedUids.size - 1) / 2.0
        val output = cachedUids.map { uid =>
          (uid, (preCol, averageRank))
        }
        cachedUids.clear()
        output
      }
      iter.flatMap { case (((j, v), uid), rank) =>
        // If we see a new value or cachedUids is too big, we flush ids with their average rank.
        if (j != preCol || v != preVal || cachedUids.size >= 10000000) {
          val output = flush()
          preCol = j
          preVal = v
          startRank = rank
          cachedUids += uid
          output
        } else {
          cachedUids += uid
          Iterator.empty
        }
      // Iterator's `++` takes its argument by name, so this final flush only runs once the
      // iterator above is exhausted; it emits the last buffered run.
      } ++ flush()
    }
    // Replace values in the input matrix by their ranks compared with values in the same column.
    // Note that shifting all ranks in a column by a constant value doesn't affect result.
    val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) =>
      // sort by column index and then convert values to a vector
      Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray)
    }
    PearsonCorrelation.computeCorrelationMatrix(groupedRanks)
  }
} 
Example 146
Source File: TestOutputStream.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming

import java.io.{IOException, ObjectInputStream}

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.dstream.{DStream, ForEachDStream}
import org.apache.spark.util.Utils

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag


/**
 * Output stream for tests: collects every RDD of `parent` and appends the collected elements
 * to `output` as one entry per RDD.
 *
 * @param parent the stream whose RDDs are collected
 * @param output buffer receiving the collected batches; a fresh empty buffer by default
 */
class TestOutputStream[T: ClassTag](parent: DStream[T],
    val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]())
  extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
    output += rdd.collect()
  }, false) {

  // This is to clear the output buffer every time it is read from a checkpoint
  @throws(classOf[IOException])
  private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
    ois.defaultReadObject()
    output.clear()
  }
}
Example 147
Source File: FlumeStreamSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.flume

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps

import com.google.common.base.Charsets
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}

/**
 * Tests the Flume input stream with and without compression.
 *
 * NOTE(review): `testFlumeStream`, which both tests call, is not defined in this listing —
 * presumably it was elided; confirm against the full source file.
 */
class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
  // Local master with four threads.
  val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")
  var ssc: StreamingContext = null

  test("flume input stream") {
    testFlumeStream(testCompression = false)
  }

  test("flume input compressed stream") {
    testFlumeStream(testCompression = true)
  }

  
  // Channel factory that installs zlib handlers at the front of every new channel's pipeline.
  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      val encoder = new ZlibEncoder(compressionLevel)
      // Both handlers are added with addFirst, so "inflater" ends up ahead of "deflater".
      pipeline.addFirst("deflater", encoder)
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }
}
Example 148
Source File: JDBCRelation.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import java.util.Properties

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}


  /**
   * Splits a table into partitions by generating WHERE clauses over the partitioning column.
   *
   * The range between the lower and upper bound is cut into `numPartitions` strides. The first
   * partition has no lower limit and the last one no upper limit, so rows outside the bounds
   * are still covered.
   *
   * @param partitioning column, bounds and partition count; null means "do not partition"
   * @return a single partition without a WHERE clause when `partitioning` is null or asks for
   *         one partition or fewer (zero and negative counts included); otherwise one
   *         partition per stride
   */
  def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
    // A count of zero would divide by zero below and a negative count would produce no
    // partitions at all, so every count <= 1 falls back to a single full-table partition.
    if (partitioning == null || partitioning.numPartitions <= 1) {
      Array[Partition](JDBCPartition(null, 0))
    } else {
      val numPartitions = partitioning.numPartitions
      val column = partitioning.column
      // Overflow and silliness can happen if you subtract then divide.
      // Here we get a little roundoff, but that's (hopefully) OK.
      val stride: Long = (partitioning.upperBound / numPartitions
                        - partitioning.lowerBound / numPartitions)
      var i: Int = 0
      var currentValue: Long = partitioning.lowerBound
      val ans = new ArrayBuffer[Partition]()
      while (i < numPartitions) {
        // The first partition is open at the bottom, the last one open at the top.
        val lowerBound = if (i != 0) s"$column >= $currentValue" else null
        currentValue += stride
        val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
        val whereClause =
          if (upperBound == null) {
            lowerBound
          } else if (lowerBound == null) {
            upperBound
          } else {
            s"$lowerBound AND $upperBound"
          }
        ans += JDBCPartition(whereClause, i)
        i = i + 1
      }
      ans.toArray
    }
  }
}

/**
 * A relation backed by a JDBC table that supports pruned and filtered scans as well as inserts.
 *
 * @param url        JDBC connection URL
 * @param table      name of the table
 * @param parts      partitions to scan
 * @param properties connection properties; empty by default
 * @param sqlContext the owning SQL context (transient)
 */
private[sql] case class JDBCRelation(
    url: String,
    table: String,
    parts: Array[Partition],
    properties: Properties = new Properties())(@transient val sqlContext: SQLContext)
  extends BaseRelation
  with PrunedFilteredScan
  with InsertableRelation {

  override val needConversion: Boolean = false

  // Resolved when the relation is constructed.
  override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

  /** Scans the table, restricted to the requested columns and filters. */
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    val scanned = JDBCRDD.scanTable(
      sqlContext.sparkContext,
      schema,
      url,
      properties,
      table,
      requiredColumns,
      filters,
      parts)
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    scanned.asInstanceOf[RDD[Row]]
  }

  /** Writes `data` to the table, replacing its contents when `overwrite` is set. */
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    val saveMode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
    data.write.mode(saveMode).jdbc(url, table, properties)
  }
}
Example 149
Source File: KPLBasedKinesisTestUtils.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.kinesis

import java.nio.ByteBuffer

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult}
import com.google.common.util.concurrent.{FutureCallback, Futures}

/** Kinesis test utilities that can send aggregated records through the KPL-based generator. */
private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils {
  /** Chooses the KPL-backed generator for aggregated sends and the simple one otherwise. */
  override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
    if (aggregate) new KPLDataGenerator(regionName)
    else new SimpleDataGenerator(kinesisClient)
  }
}


/** Data generator that publishes records through the Kinesis Producer Library. */
private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator {

  // Created on first use so that merely constructing the generator does not build a producer.
  private lazy val producer: KPLProducer = {
    val conf = new KinesisProducerConfiguration()
      .setRecordMaxBufferedTime(1000)
      .setMaxConnections(1)
      .setRegion(regionName)
      .setMetricsLevel("none")

    new KPLProducer(conf)
  }

  /**
   * Sends every number in `data` to `streamName` (the number's string form is both partition
   * key and payload), flushes the producer, and returns the (number, sequence number) pairs
   * recorded by the success callbacks, grouped by shard id.
   */
  override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = {
    val sentByShard = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
    for (num <- data) {
      val text = num.toString
      val payload = ByteBuffer.wrap(text.getBytes())
      val pending = producer.addUserRecord(streamName, text, payload)
      Futures.addCallback(pending, new FutureCallback[UserRecordResult]() {
        override def onFailure(t: Throwable): Unit = {} // do nothing

        override def onSuccess(result: UserRecordResult): Unit = {
          val shardId = result.getShardId
          val seqNumber = result.getSequenceNumber()
          val bucket = sentByShard.getOrElseUpdate(shardId, new ArrayBuffer[(Int, String)]())
          bucket += ((num, seqNumber))
        }
      })
    }
    producer.flushSync()
    sentByShard.toMap
  }
}
Example 150
Source File: UnionDStream.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.dstream

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.SparkException
import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.UnionRDD

/**
 * DStream whose batch RDD is the union of the batch RDDs of all `parents`.
 *
 * NOTE(review): `parents.head.ssc` in the superclass call is evaluated before the `require`
 * checks below, so an empty `parents` array fails on `head` rather than with the message.
 */
private[streaming]
class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
  extends DStream[T](parents.head.ssc) {

  require(parents.length > 0, "List of DStreams to union is empty")
  require(parents.map(_.ssc).distinct.size == 1, "Some of the DStreams have different contexts")
  require(parents.map(_.slideDuration).distinct.size == 1,
    "Some of the DStreams have different slide durations")

  override def dependencies: List[DStream[_]] = parents.toList

  override def slideDuration: Duration = parents.head.slideDuration

  /** Unions the RDDs of every parent for `validTime`; throws if any parent has none. */
  override def compute(validTime: Time): Option[RDD[T]] = {
    // Every parent is asked for its RDD before any result is inspected.
    val computed = parents.map(_.getOrCompute(validTime))
    val collected = new ArrayBuffer[RDD[T]]()
    for (maybeRdd <- computed) {
      maybeRdd match {
        case Some(rdd) => collected += rdd
        case None => throw new SparkException("Could not generate RDD from a parent for unifying at" +
          s" time $validTime")
      }
    }
    if (collected.nonEmpty) Some(new UnionRDD(ssc.sc, collected)) else None
  }
}
Example 151
Source File: QueueInputDStream.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.dstream

import java.io.{NotSerializableException, ObjectInputStream, ObjectOutputStream}

import scala.collection.mutable.{ArrayBuffer, Queue}
import scala.reflect.ClassTag

import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.streaming.{Time, StreamingContext}

/**
 * Input stream that serves RDDs taken from an in-memory queue.
 *
 * Each batch takes either a single queued RDD (`oneAtATime`) or everything currently queued.
 * When nothing was taken, `defaultRDD` is served if it is non-null, otherwise an empty RDD.
 */
private[streaming]
class QueueInputDStream[T: ClassTag](
    ssc: StreamingContext,
    val queue: Queue[RDD[T]],
    oneAtATime: Boolean,
    defaultRDD: RDD[T]
  ) extends InputDStream[T](ssc) {

  override def start(): Unit = { }

  override def stop(): Unit = { }

  // Java serialization hook: reading this stream back always fails.
  private def readObject(in: ObjectInputStream): Unit = {
    throw new NotSerializableException("queueStream doesn't support checkpointing. " +
      "Please don't use queueStream when checkpointing is enabled.")
  }

  // Java serialization hook: writing only logs a warning and stores nothing.
  private def writeObject(oos: ObjectOutputStream): Unit = {
    logWarning("queueStream doesn't support checkpointing")
  }

  /** Produces the RDD for this batch from the queue contents. */
  override def compute(validTime: Time): Option[RDD[T]] = {
    val taken = new ArrayBuffer[RDD[T]]()
    if (oneAtATime && queue.nonEmpty) {
      taken += queue.dequeue()
    } else {
      taken ++= queue.dequeueAll(_ => true)
    }
    if (taken.isEmpty) {
      if (defaultRDD != null) {
        Some(defaultRDD)
      } else {
        Some(ssc.sparkContext.emptyRDD)
      }
    } else if (oneAtATime) {
      Some(taken.head)
    } else {
      Some(new UnionRDD(context.sc, taken.toSeq))
    }
  }

}
Example 152
Source File: LocalSparkCluster.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.deploy

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.rpc.RpcEnv
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.deploy.master.Master
import org.apache.spark.util.Utils


    for (workerNum <- 1 to numWorkers) {
      val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker,
        memoryPerWorker, masters, null, Some(workerNum), _conf)
      workerRpcEnvs += workerEnv
    }

    masters
  }

  /** Shuts down every worker and master RPC environment, waits for them, then forgets them. */
  def stop(): Unit = {
    logInfo("Shutting down local Spark cluster.")
    // Workers are asked to shut down ahead of the masters so they never see a master vanish.
    for (workerEnv <- workerRpcEnvs) workerEnv.shutdown()
    for (masterEnv <- masterRpcEnvs) masterEnv.shutdown()
    for (workerEnv <- workerRpcEnvs) workerEnv.awaitTermination()
    for (masterEnv <- masterRpcEnvs) masterEnv.awaitTermination()
    masterRpcEnvs.clear()
    workerRpcEnvs.clear()
  }
} 
Example 153
Source File: Schedulable.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.scheduler.SchedulingMode.SchedulingMode


/**
 * Contract for an entity that takes part in scheduling. Only abstract members are declared
 * here; the implementations are not visible in this file.
 * NOTE(review): presumably implemented by both pools and task-set managers -- confirm.
 */
private[spark] trait Schedulable {
  // Enclosing pool; declared as a var, so implementors must allow reassignment.
  var parent: Pool
  // child queues
  def schedulableQueue: ConcurrentLinkedQueue[Schedulable]
  def schedulingMode: SchedulingMode
  def weight: Int
  def minShare: Int
  def runningTasks: Int
  def priority: Int
  def stageId: Int
  def name: String

  // Child management and lookup by name.
  def addSchedulable(schedulable: Schedulable): Unit
  def removeSchedulable(schedulable: Schedulable): Unit
  def getSchedulableByName(name: String): Schedulable
  // Notification hook carrying the lost executor's id, its host and the loss reason.
  def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit
  def checkSpeculatableTasks(): Boolean
  def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager]
}
Example 154
Source File: ByteArrayChunkOutputStream.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util.io

import java.io.OutputStream

import scala.collection.mutable.ArrayBuffer



  // Write offset inside the chunk at lastChunkIndex. It starts at chunkSize so that the first
  // write finds the "current chunk" full and allocateNewChunkIfNeeded() creates the first chunk.
  private var position = chunkSize

  /** Appends the low byte of `b`, allocating a new chunk first when the current one is full. */
  override def write(b: Int): Unit = {
    allocateNewChunkIfNeeded()
    val currentChunk = chunks(lastChunkIndex)
    currentChunk(position) = b.toByte
    position += 1
  }

  /** Appends `len` bytes of `bytes` starting at `off`, spilling into new chunks as needed. */
  override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
    var copied = 0
    while (copied < len) {
      allocateNewChunkIfNeeded()
      // Copy as much as fits into the current chunk, capped by what is still left to write.
      val roomInChunk = chunkSize - position
      val batch = math.min(roomInChunk, len - copied)
      System.arraycopy(bytes, off + copied, chunks(lastChunkIndex), position, batch)
      copied += batch
      position += batch
    }
  }

  /** Starts a fresh chunk once the current one has been filled completely. */
  @inline
  private def allocateNewChunkIfNeeded(): Unit = {
    val currentChunkIsFull = position == chunkSize
    if (currentChunkIsFull) {
      val freshChunk = new Array[Byte](chunkSize)
      chunks += freshChunk
      lastChunkIndex += 1
      position = 0
    }
  }

  /**
   * Returns the written data as an array of chunks. All chunks but the last are returned as
   * they are (not copied); the last chunk is trimmed to the bytes actually written unless it
   * is exactly full.
   */
  def toArrays: Array[Array[Byte]] = {
    if (lastChunkIndex == -1) {
      new Array[Array[Byte]](0)
    } else {
      // Handing out full chunks plus a right-sized final array avoids keeping an oversized
      // last chunk alive; a bounded ByteBuffer view would still pin the whole chunk.
      val result = new Array[Array[Byte]](chunks.size)
      var i = 0
      while (i < chunks.size - 1) {
        result(i) = chunks(i)
        i += 1
      }
      val lastChunk = chunks(lastChunkIndex)
      if (position == chunkSize) {
        result(lastChunkIndex) = lastChunk
      } else {
        val trimmed = new Array[Byte](position)
        System.arraycopy(lastChunk, 0, trimmed, 0, position)
        result(lastChunkIndex) = trimmed
      }
      result
    }
  }
} 
Example 155
Source File: UnionRDD.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils


/**
 * Partition of a union that points at a single partition of one parent RDD.
 *
 * @param idx index of this partition
 * @param rdd parent RDD that owns the referenced partition (transient, not serialized)
 * @param parentRddIndex position of the parent RDD in the union's list of RDDs
 * @param parentRddPartitionIndex index of the referenced partition inside `rdd` (transient)
 */
private[spark] class UnionPartition[T: ClassTag](
    idx: Int,
    @transient private val rdd: RDD[T],
    val parentRddIndex: Int,
    @transient private val parentRddPartitionIndex: Int)
  extends Partition {

  // Resolved eagerly at construction and refreshed again in writeObject.
  var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)

  // Delegates to the parent RDD's preferred locations for the referenced partition.
  def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition)

  override val index: Int = idx

  // Java serialization hook; any failure is routed through Utils.tryOrIOException.
  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to parent split at the time of task serialization
    parentPartition = rdd.partitions(parentRddPartitionIndex)
    oos.defaultWriteObject()
  }
}

/**
 * RDD that concatenates the partitions of several parent RDDs, in the order the parents
 * appear in `rdds`.
 */
@DeveloperApi
class UnionRDD[T: ClassTag](
    sc: SparkContext,
    var rdds: Seq[RDD[T]])
  extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies

  /** Builds one [[UnionPartition]] per parent partition, numbered consecutively. */
  override def getPartitions: Array[Partition] = {
    val totalPartitions = rdds.map(_.partitions.length).sum
    val result = new Array[Partition](totalPartitions)
    var next = 0
    rdds.zipWithIndex.foreach { case (rdd, rddIndex) =>
      rdd.partitions.foreach { split =>
        result(next) = new UnionPartition(next, rdd, rddIndex, split.index)
        next += 1
      }
    }
    result
  }

  /** Maps each parent's partition range onto its slice of the union's partition range. */
  override def getDependencies: Seq[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var outputStart = 0
    rdds.foreach { rdd =>
      deps += new RangeDependency(rdd, 0, outputStart, rdd.partitions.length)
      outputStart += rdd.partitions.length
    }
    deps
  }

  /** Reads the referenced parent partition through the matching parent RDD. */
  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val unionPart = s.asInstanceOf[UnionPartition[T]]
    val parentRdd = parent[T](unionPart.parentRddIndex)
    parentRdd.iterator(unionPart.parentPartition, context)
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    val unionPart = s.asInstanceOf[UnionPartition[T]]
    unionPart.preferredLocations()
  }

  /** Drops the references to the parent RDDs in addition to the inherited cleanup. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    rdds = null
  }
}
Example 156
Source File: TaskContextImpl.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import scala.collection.mutable.{ArrayBuffer, HashMap}

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}

/**
 * Concrete [[TaskContext]] holding the identifiers of a task attempt, its completion
 * listeners, its interruption/completion flags and the accumulators registered for it.
 *
 * NOTE(review): in the code visible here `completed` is only ever read (by `isCompleted`);
 * whatever sets it and invokes the listeners is not part of this excerpt -- confirm.
 */
private[spark] class TaskContextImpl(
    val stageId: Int,
    val partitionId: Int,
    override val taskAttemptId: Long,
    override val attemptNumber: Int,
    override val taskMemoryManager: TaskMemoryManager,
    @transient private val metricsSystem: MetricsSystem,
    internalAccumulators: Seq[Accumulator[Long]],
    val runningLocally: Boolean = false,
    val taskMetrics: TaskMetrics = TaskMetrics.empty)
  extends TaskContext
  with Logging {

  // For backwards-compatibility; this method is now deprecated as of 1.3.0.
  override def attemptId(): Long = taskAttemptId

  // List of callback functions to execute when the task completes.
  @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]

  // Whether the corresponding task has been killed.
  @volatile private var interrupted: Boolean = false

  // Whether the task has completed.
  @volatile private var completed: Boolean = false

  // Registers a listener object; returns this context so calls can be chained.
  override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
    onCompleteCallbacks += listener
    this
  }

  // Function-based overload: wraps `f` in a TaskCompletionListener before registering it.
  override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
    onCompleteCallbacks += new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = f(context)
    }
    this
  }

  // Legacy variant whose callback takes no argument; the context passed on completion is dropped.
  @deprecated("use addTaskCompletionListener", "1.1.0")
  override def addOnCompleteCallback(f: () => Unit) {
    onCompleteCallbacks += new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = f()
    }
  }

  
  // Flags the task as interrupted; the flag is volatile and read back by isInterrupted().
  private[spark] def markInterrupted(): Unit = {
    interrupted = true
  }

  override def isCompleted(): Boolean = completed

  override def isRunningLocally(): Boolean = runningLocally

  override def isInterrupted(): Boolean = interrupted

  override def getMetricsSources(sourceName: String): Seq[Source] =
    metricsSystem.getSourcesByName(sourceName)

  // Accumulators registered for this task, keyed by accumulator id. Declared before
  // internalMetricsToAccumulators because that initializer registers into this map.
  @transient private val accumulators = new HashMap[Long, Accumulable[_, _]]

  private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized {
    accumulators(a.id) = a
  }

  // Local values of the registered accumulators whose isInternal flag is set, keyed by id.
  private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized {
    accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
  }

  // Local values of every registered accumulator, keyed by id.
  private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
    accumulators.mapValues(_.localValue).toMap
  }

  //private[spark]
  // Name -> accumulator lookup for the internal accumulators; `a.name.get` assumes each of
  // them carries a name.
  override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = {
    // Explicitly register internal accumulators here because these are
    // not captured in the task closure and are already deserialized
    internalAccumulators.foreach(registerAccumulator)
    internalAccumulators.map { a => (a.name.get, a) }.toMap
  }
}
Example 157
Source File: LinkerdApi.scala    From asura   with MIT License 5 votes vote down vote up
package asura.app.api

import asura.app.api.model.Dtabs
import asura.app.api.model.Dtabs.DtabItem
import asura.common.exceptions.ErrorMessages.ErrorMessage
import asura.common.model.{ApiRes, ApiResError}
import asura.core.http.HttpEngine
import asura.core.{CoreConfig, ErrorMessages}
import asura.namerd.DtabEntry
import asura.namerd.api.v1.NamerdV1Api
import asura.play.api.BaseApi.OkApiRes
import javax.inject.{Inject, Singleton}
import org.pac4j.play.scala.SecurityComponents
import play.api.Configuration

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}


/**
 * HTTP endpoints for inspecting and updating the dtab entries of the configured linkerd
 * proxy servers. Every endpoint answers with the "proxy disabled" error while linkerd
 * support is switched off in the core configuration.
 */
@Singleton
class LinkerdApi @Inject()(
                            implicit val exec: ExecutionContext,
                            val controllerComponents: SecurityComponents,
                            config: Configuration,
                          ) extends BaseApi {

  val srcPrefix = "/svc/"
  val dstPrefix = "/$/inet/"

  /** Lists the configured proxy servers. */
  def getProxyServers() = Action { implicit req =>
    if (!CoreConfig.linkerdConfig.enabled) {
      OkApiRes(ApiResError(getI18nMessage(ErrorMessages.error_ProxyDisabled.name)))
    } else {
      OkApiRes(ApiRes(data = CoreConfig.linkerdConfig.servers))
    }
  }

  /**
   * Fetches the dtabs of the HTTP namespace of `server` and returns those whose prefix and
   * destination both split into five "/"-separated parts, flagging the ones that belong to
   * the given group and project.
   */
  def getHttp(group: String, project: String, server: String) = Action.async { implicit req =>
    if (!CoreConfig.linkerdConfig.enabled) {
      Future.successful(OkApiRes(ApiResError(getI18nMessage(ErrorMessages.error_ProxyDisabled.name))))
    } else {
      val proxyServer = CoreConfig.linkerdConfig.servers.find(_.tag.equals(server)).get
      NamerdV1Api.getNamespaceDtabs(proxyServer.namerd, proxyServer.httpNs)(HttpEngine.http).map { dtabs =>
        val items = ArrayBuffer[DtabItem]()
        for (entry <- dtabs) {
          val prefixParts = entry.prefix.split("/")
          val dstParts = entry.dst.split("/")
          val wellFormed = prefixParts.length == 5 && dstParts.length == 5
          if (wellFormed) {
            items += DtabItem(
              group = prefixParts(2),
              project = prefixParts(3),
              namespace = prefixParts(4),
              host = dstParts(3),
              port = dstParts(4),
              owned = group == prefixParts(2) && project == prefixParts(3)
            )
          }
        }
        toActionResultFromAny(items)
      }
    }
  }

  /**
   * Replaces the dtabs of the HTTP namespace of `server` with the entries in the request
   * body. Validation stops at the first invalid item, whose error fails the request; an
   * absent or empty body is answered with an empty success response.
   */
  def putHttp(group: String, project: String, server: String) = Action(parse.byteString).async { implicit req =>
    if (!CoreConfig.linkerdConfig.enabled) {
      Future.successful(OkApiRes(ApiResError(getI18nMessage(ErrorMessages.error_ProxyDisabled.name))))
    } else {
      val proxyServer = CoreConfig.linkerdConfig.servers.find(_.tag.equals(server)).get
      val dtabs = req.bodyAs(classOf[Dtabs])
      val hasEntries = null != dtabs && null != dtabs.dtabs && dtabs.dtabs.nonEmpty
      if (!hasEntries) {
        Future.successful(OkApiRes(ApiRes()))
      } else {
        var error: ErrorMessage = null
        val entries = ArrayBuffer[DtabEntry]()
        // The guard is re-checked for every index, so nothing is processed after a failure.
        for (i <- 0 until dtabs.dtabs.length if null == error) {
          val item = dtabs.dtabs(i)
          error = item.isValid()
          val source = s"${srcPrefix}${item.group}/${item.project}/${item.namespace}"
          val destination = s"${dstPrefix}${item.host}/${item.port}"
          entries += DtabEntry(source, destination)
        }
        if (null != error) {
          error.toFutureFail
        } else {
          NamerdV1Api.updateNamespaceDtabs(proxyServer.namerd, proxyServer.httpNs, entries)(HttpEngine.http).toOkResult
        }
      }
    }
  }
}
Example 158
Source File: InterfaceMethodParamsActor.scala    From asura   with MIT License 5 votes vote down vote up
package asura.dubbo.actor

import akka.actor.{ActorRef, Props, Status}
import akka.pattern.pipe
import akka.util.ByteString
import asura.common.actor.BaseActor
import asura.common.util.LogUtils
import asura.dubbo.DubboConfig
import asura.dubbo.actor.GenericServiceInvokerActor.GetInterfaceMethodParams
import asura.dubbo.model.InterfaceMethodParams
import asura.dubbo.model.InterfaceMethodParams.MethodSignature

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{ExecutionContext, Future}

/**
 * One-shot actor that asks a Dubbo provider, through a child telnet client, for the method
 * signatures of the interface `msg.ref` and pipes the parsed result (or a failure) to
 * `invoker`.
 *
 * @param invoker actor that receives the [[InterfaceMethodParams]] or the failure
 * @param msg     request carrying the provider address, port and interface reference
 */
class InterfaceMethodParamsActor(invoker: ActorRef, msg: GetInterfaceMethodParams) extends BaseActor {

  implicit val ec: ExecutionContext = context.dispatcher
  // A non-positive port in the request falls back to the default Dubbo port.
  private val telnet: ActorRef = context.actorOf(TelnetClientActor.props(msg.address, if (msg.port > 0) msg.port else DubboConfig.DEFAULT_PORT, self))

  override def receive: Receive = {
    case telnetData: ByteString =>
      val utf8String = telnetData.utf8String
      if (utf8String.contains(TelnetClientActor.MSG_CONNECT_TO)) {
        // Connection banner: on success request the method listing, otherwise fail the invoker.
        log.debug(utf8String)
        if (utf8String.contains(TelnetClientActor.MSG_SUCCESS)) {
          telnet ! ByteString(s"ls -l ${msg.ref}\r\n")
        } else if (utf8String.contains(TelnetClientActor.MSG_FAIL)) {
          Future.failed(new RuntimeException(s"Remote connection to ${msg.address}:${msg.port} failed")) pipeTo invoker
          telnet ! TelnetClientActor.CMD_CLOSE
          context stop self
        } else {
          Future.failed(new RuntimeException(s"Unknown response ${utf8String}")) pipeTo invoker
          telnet ! TelnetClientActor.CMD_CLOSE
          context stop self
        }
      } else if (utf8String.contains("(") && utf8String.contains(")")) {
        // Looks like a method listing.
        // NOTE(review): unlike the other terminal branches this one does not stop the actor
        // itself -- confirm that closing the telnet client is what ends it.
        getInterfaceMethodParams(msg.ref, utf8String) pipeTo invoker
        telnet ! TelnetClientActor.CMD_CLOSE
      } else {
        Future.failed(new RuntimeException(s"Unknown response: ${utf8String}")) pipeTo invoker
        telnet ! TelnetClientActor.CMD_CLOSE
        context stop self
      }
    case Status.Failure(t) =>
      val stackTrace = LogUtils.stackTraceToString(t)
      log.warning(stackTrace)
      context stop self
  }

  /**
   * Parses the output of `ls -l <interface>` into method signatures.
   *
   * Lines starting with the telnet prompt are ignored. A line contributes a signature only
   * when it consists of exactly two space-separated tokens, `<returnType> <name>(<params>)`.
   * Lines whose second token has no "(" or nothing after it are skipped, like the other
   * unrecognised lines, rather than raising from `substring` inside the actor's `receive`.
   *
   * @param ref     interface reference the listing belongs to
   * @param content raw telnet response
   * @return an already-completed future holding the parsed signatures
   */
  def getInterfaceMethodParams(ref: String, content: String): Future[InterfaceMethodParams] = {
    Future.successful {
      val methods = ArrayBuffer[MethodSignature]()
      content.split("\r\n")
        .filter(!_.startsWith(DubboConfig.DEFAULT_PROMPT))
        .foreach(signature => {
          val splits = signature.split(" ")
          if (splits.length == 2) {
            val ret = splits(0)
            val secondPart = splits(1)
            val idx = secondPart.indexOf("(")
            // Both substring calls below need "(" to exist and not be the last character.
            if (idx >= 0 && idx < secondPart.length - 1) {
              val method = secondPart.substring(0, idx)
              val params = secondPart.substring(idx + 1, secondPart.length - 1).split(",")
              methods += (MethodSignature(ret, method, params))
            }
          }
        })
      InterfaceMethodParams(ref, methods)
    }
  }

  override def postStop(): Unit = log.debug(s"${self.path} stopped")
}

object InterfaceMethodParamsActor {
  /** Props for an actor that reports the methods of `msg.ref` back to `invoker`. */
  def props(invoker: ActorRef, msg: GetInterfaceMethodParams) =
    Props(new InterfaceMethodParamsActor(invoker, msg))
}
Example 159
Source File: JobReportDataItemSaveActor.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.job.actor

import akka.actor.{Props, Status}
import asura.common.actor.BaseActor
import asura.common.util.LogUtils
import asura.core.actor.messages.Flush
import asura.core.es.model.JobReportDataItem
import asura.core.es.service.JobReportDataItemService
import asura.core.job.actor.JobReportDataItemSaveActor.SaveReportDataHttpItemMessage

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._

/**
 * Buffers report data items and hands them to the index service in batches: as soon as ten
 * items are buffered, when a `Flush` arrives (one is scheduled two seconds after every
 * received item), and once more when the actor stops.
 */
class JobReportDataItemSaveActor(dayIndexSuffix: String) extends BaseActor {

  val messages = ArrayBuffer[SaveReportDataHttpItemMessage]()

  override def receive: Receive = {
    case m: SaveReportDataHttpItemMessage =>
      messages += m
      if (messages.length >= 10) insert()
      // Delayed flush so a partially filled buffer does not sit around indefinitely.
      context.system.scheduler.scheduleOnce(2.seconds) {
        self ! Flush
      }(context.system.dispatcher)
    case Flush => insert()
    case Status.Failure(t) => log.warning(LogUtils.stackTraceToString(t))
  }

  override def preStart(): Unit = {}

  override def postStop(): Unit = {
    // Hand over whatever is still buffered before going away.
    insert()
    log.debug(s"${self.path} is stopped")
  }

  /** Sends the buffered items to the index service and empties the buffer. */
  private def insert(): Unit = {
    if (messages.nonEmpty) {
      log.debug(s"${messages.length} items is saving...")
      JobReportDataItemService.index(messages, dayIndexSuffix)
      messages.clear()
    }
  }
}

// Companion: actor factory and the message type the actor buffers.
object JobReportDataItemSaveActor {

  // Props for an actor that passes `dayIndexSuffix` along with every batch it indexes.
  def props(dayIndexSuffix: String) = Props(new JobReportDataItemSaveActor(dayIndexSuffix))

  // One report data item to save, together with its id.
  case class SaveReportDataHttpItemMessage(id: String, dataItem: JobReportDataItem)

}
Example 160
Source File: JobStatusActor.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.job.actor

import akka.actor.Status.Failure
import akka.actor.{ActorRef, Props}
import asura.common.actor._
import asura.common.model.Pagination
import asura.core.model.QueryJob
import asura.core.es.service.JobService
import asura.core.job.actor.JobStatusMonitorActor.JobStatusOperationMessage
import asura.core.job.eventbus.JobStatusBus.JobStatusNotificationMessage
import asura.core.job.{JobListItem, JobStates}
import asura.core.redis.RedisJobState
import asura.core.util.JacksonSupport
import com.typesafe.scalalogging.Logger

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
 * Actor that answers job list queries for one outbound actor and forwards status
 * notifications for the jobs returned by the most recent query.
 *
 * It starts by waiting for a `SenderMessage`; from then on it behaves as `query(sender)`.
 */
class JobStatusActor() extends BaseActor {

  // Most recent query; only used for the log line in postStop.
  var query: QueryJob = null
  // Ids of the jobs returned by the last query; notifications for other names are dropped.
  val watchIds = mutable.HashSet[String]()

  override def receive: Receive = {
    case SenderMessage(sender) =>
      context.become(query(sender))
  }

  // Behaviour once the outbound actor is known.
  def query(outSender: ActorRef): Receive = {
    case query: QueryJob =>
      this.query = query
      // NOTE(review): the callback below runs on context.system.dispatcher and mutates
      // `watchIds`; if queryJob returns a Future this happens outside the actor's own
      // message processing -- confirm that this is safe.
      JobService.queryJob(query).map(esResponse =>
        if (esResponse.isSuccess) {
          val items = ArrayBuffer[JobListItem]()
          val jobsTable = mutable.HashMap[String, JobListItem]()
          val hits = esResponse.result.hits
          watchIds.clear()
          hits.hits.foreach(hit => {
            val jobId = hit.id
            watchIds.add(jobId)
            // The same item instance goes into both `items` and `jobsTable`, so the state
            // written through jobsTable later is visible in the list that is sent out.
            jobsTable += (jobId -> {
              val item = JacksonSupport.parse(hit.sourceAsString, classOf[JobListItem])
              item.state = JobStates.UNKNOWN
              items += item
              item._id = jobId
              item
            })
          })
          if (watchIds.nonEmpty) {
            RedisJobState.getJobState(watchIds.toSet).onComplete {
              case util.Success(statesMap) =>
                // NOTE(review): jobsTable(jobKey) throws for a key that is not in the
                // table -- presumably the returned keys are a subset of watchIds; verify.
                statesMap.forEach((jobKey, state) => jobsTable(jobKey).state = state)
                outSender ! ListActorEvent(Map("total" -> hits.total, "list" -> items))
              case util.Failure(_) =>
                // State lookup failed: the list is still sent, with every state left UNKNOWN.
                outSender ! ListActorEvent(Map("total" -> hits.total, "list" -> items))
            }(context.system.dispatcher)
          } else {
            outSender ! ListActorEvent(Map("total" -> 0, "list" -> Nil))
          }
        } else {
          outSender ! ErrorActorEvent(esResponse.error.reason)
        })(context.system.dispatcher)
    case JobStatusNotificationMessage(_, operator, scheduler, group, name, data) =>
      // Only forward notifications for jobs that were part of the last query result.
      if (watchIds.contains(name)) {
        outSender ! ItemActorEvent(JobStatusOperationMessage(operator, scheduler, group, name, data))
      }
    case eventMessage: ActorEvent =>
      outSender ! eventMessage
    case Failure(t) =>
      outSender ! ErrorActorEvent(t.getMessage)
  }

  override def postStop(): Unit = {
    import JobStatusActor.logger
    logger.debug(s"JobStatus for ${query} stopped")
  }
}

// Companion: logger, actor factory and a paginated query message type.
object JobStatusActor {

  // Logger used by the actor's postStop.
  val logger = Logger(classOf[JobStatusActor])

  def props() = Props(new JobStatusActor())

  // Query message with optional (null by default) scheduler, group and text fields.
  case class JobQueryMessage(scheduler: String = null, group: String = null, text: String = null) extends Pagination

}
Example 161
Source File: HeaderUtils.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.http

import akka.http.scaladsl.model.HttpHeader.ParsingResult.{Error, Ok}
import akka.http.scaladsl.model.headers.{Cookie, RawHeader}
import akka.http.scaladsl.model.{ErrorInfo, HttpHeader}
import asura.common.util.StringUtils
import asura.core.es.model.{Environment, HttpCaseRequest}
import asura.core.runtime.RuntimeContext
import asura.core.{CoreConfig, ErrorMessages}
import com.typesafe.scalalogging.Logger

import scala.collection.immutable
import scala.collection.mutable.ArrayBuffer

/** Helpers for turning a stored HTTP case request into akka-http headers. */
object HeaderUtils {

  val logger = Logger("HeaderUtils")

  /**
   * Builds the header list for a case request, in this order: the request's own enabled
   * headers, a single `Cookie` header built from its enabled cookies, the enabled headers of
   * the environment in use and, when the environment enables the proxy, the proxy routing
   * header. Header and cookie values are rendered through `context` first. Headers that
   * fail to parse are logged and skipped.
   *
   * @param cs      case whose request headers and cookies are converted
   * @param context runtime context used to render values and to find the environment in use
   * @return the resulting headers as an immutable sequence
   * @throws Exception as raised by [[validateProxyVariables]] when the proxy is enabled for
   *                   the environment but misconfigured
   */
  def toHeaders(cs: HttpCaseRequest, context: RuntimeContext): immutable.Seq[HttpHeader] = {
    val headers = ArrayBuffer[HttpHeader]()
    val request = cs.request
    val env = if (null != context.options) context.options.getUsedEnv() else null
    if (null != request) {
      val headerSeq = request.header
      if (null != headerSeq) {
        for (h <- headerSeq if (h.enabled && StringUtils.isNotEmpty(h.key))) {
          appendParsedHeader(headers, h.key, context.renderSingleMacroAsString(h.value))
        }
      }
      val cookieSeq = request.cookie
      if (null != cookieSeq) {
        val cookies = ArrayBuffer[(String, String)]()
        for (c <- cookieSeq if (c.enabled && StringUtils.isNotEmpty(c.key))) {
          cookies += ((c.key, context.renderSingleMacroAsString(c.value)))
        }
        if (cookies.nonEmpty) headers += Cookie(cookies: _*)
      }
    }
    if (null != env && null != env.headers && env.headers.nonEmpty) {
      for (h <- env.headers if (h.enabled && StringUtils.isNotEmpty(h.key))) {
        appendParsedHeader(headers, h.key, context.renderSingleMacroAsString(h.value))
      }
    }
    if (null != env && env.enableProxy) {
      val headerIdentifier = validateProxyVariables(env)
      // Routing target in the form /<group>/<project>/<namespace>.
      val dst = s"/${cs.group}/${cs.project}/${env.namespace}"
      headers += RawHeader(headerIdentifier, dst)
    }
    headers.toList
  }

  /** Parses one header and appends it to `headers`; parse problems are only logged. */
  private def appendParsedHeader(headers: ArrayBuffer[HttpHeader], key: String, value: String): Unit = {
    HttpHeader.parse(key, value) match {
      case Ok(header: HttpHeader, errors: List[ErrorInfo]) =>
        if (errors.nonEmpty) logger.warn(errors.mkString(","))
        headers += header
      case Error(error: ErrorInfo) =>
        logger.warn(error.detail)
    }
  }

  /**
   * Checks that the proxy can be used for `env` and returns the name of the routing header
   * of the proxy server the environment points at.
   *
   * Fails when linkerd support is disabled, when the environment has no namespace or no
   * server, and when the server tag matches no configured proxy server or the matching
   * server has an empty header identifier. The last two cases were previously combined with
   * `&&`, so a missing server surfaced as a `NoSuchElementException` from `Option.get` and
   * an empty identifier was returned as if it were valid.
   *
   * @param env environment to validate
   * @return the non-empty header identifier of the matching proxy server
   */
  def validateProxyVariables(env: Environment): String = {
    if (!CoreConfig.linkerdConfig.enabled) {
      throw ErrorMessages.error_ProxyDisabled.toException
    }
    if (StringUtils.isEmpty(env.namespace)) {
      throw ErrorMessages.error_EmptyNamespace.toException
    }
    if (StringUtils.isEmpty(env.server)) {
      throw ErrorMessages.error_EmptyProxyServer.toException
    }
    CoreConfig.linkerdConfig.servers.find(_.tag.equals(env.server)) match {
      case Some(proxyServer) if StringUtils.isNotEmpty(proxyServer.headerIdentifier) =>
        proxyServer.headerIdentifier
      case _ =>
        throw ErrorMessages.error_InvalidProxyConfig.toException
    }
  }

  /** True when `header` is a content-type header whose value contains the JSON media type. */
  def isApplicationJson(header: HttpHeader): Boolean = {
    if (header.lowercaseName().equals("content-type")) {
      header.value().contains(HttpContentTypes.JSON)
    } else {
      false
    }
  }
}
Example 162
Source File: Or.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.assertion

import asura.core.concurrent.ExecutionContextManager.cachedExecutor
import asura.core.assertion.engine.{AssertResult, AssertionContext, FailAssertResult, Statistic}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future

/**
 * Logical OR over a list of sub-assertions: the result is successful as soon as at least
 * one sub-assertion is successful. Pass/fail counts of all sub-assertions are accumulated.
 */
object Or extends Assertion {

  override val name: String = Assertions.OR

  override def assert(actual: Any, expect: Any): Future[AssertResult] = {
    apply(actual, expect)
  }

  /**
   * Evaluates every sub-assertion in `except` against `actual`.
   *
   * @param actual value the sub-assertions are evaluated against
   * @param except expected side; must be a sequence of assertion maps. (The name is kept
   *               as it is so that callers using named arguments keep compiling.)
   * @return the combined result; a future holding `null` for an empty sequence; a failed
   *         assert result when `except` is not a sequence
   */
  def apply(actual: Any, except: Any): Future[AssertResult] = {
    // Starts out failed and is flipped to passed by the first successful sub-assertion.
    val result = AssertResult(
      isSuccessful = false,
      msg = AssertResult.MSG_FAILED
    )
    except match {
      case assertions: Seq[_] =>
        if (assertions.nonEmpty) {
          val assertionResults = assertions.map(assertion => {
            val subStatis = Statistic()
            val assertionMap = assertion.asInstanceOf[Map[String, Any]]
            val contextMap = actual.asInstanceOf[Object]
            AssertionContext.eval(assertionMap, contextMap, subStatis).map((subStatis, _))
          })
          Future.sequence(assertionResults).map(subStatisResults => {
            // Single buffer for the sub-results; the earlier placeholder buffer with a
            // different element type was shadowed by this one and never observable.
            val subResults = ArrayBuffer[java.util.Map[String, Any]]()
            result.subResult = subResults
            subStatisResults.foreach(subStatisResult => {
              val (subStatis, subResult) = subStatisResult
              subResults += subResult
              result.pass(subStatis.passed)
              result.fail(subStatis.failed)
              if (subStatis.isSuccessful) {
                result.isSuccessful = true
                result.msg = AssertResult.MSG_PASSED
              }
            })
            result
          })
        } else {
          Future.successful(null)
        }
      case _ =>
        Future.successful(FailAssertResult(1, AssertResult.msgIncomparableTargetType(except)))
    }
  }
}
Example 163
Source File: And.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.assertion

import asura.core.concurrent.ExecutionContextManager.cachedExecutor
import asura.core.assertion.engine.{AssertResult, AssertionContext, FailAssertResult, Statistic}

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future

/**
 * Logical AND over a list of sub-assertions.
 *
 * The aggregated result starts out as passed and is flipped to failed as soon
 * as the statistic of any sub-assertion reports a failure.
 */
object And extends Assertion {

  override val name: String = Assertions.AND

  /** Entry point of the `Assertion` contract; simply delegates to `apply`. */
  override def assert(actual: Any, expect: Any): Future[AssertResult] = apply(actual, expect)

  /**
   * Evaluates every sub-assertion against `actual` and merges the outcomes.
   *
   * @param actual value the sub-assertions are evaluated against
   * @param expect the expectation; must be a `Seq` whose elements are assertion maps
   * @return the merged result for a non-empty `Seq`; a future holding `null`
   *         for an empty `Seq`; a `FailAssertResult` when `expect` is not a `Seq`
   */
  def apply(actual: Any, expect: Any): Future[AssertResult] = {
    val aggregate = AssertResult(
      isSuccessful = true,
      msg = AssertResult.MSG_PASSED
    )
    expect match {
      case assertions: Seq[_] =>
        if (assertions.isEmpty) {
          Future.successful(null)
        } else {
          // One statistic per sub-assertion, paired with the evaluation outcome.
          val pending = assertions.map { assertion =>
            val statistic = Statistic()
            AssertionContext
              .eval(assertion.asInstanceOf[Map[String, Any]], actual.asInstanceOf[Object], statistic)
              .map(outcome => (statistic, outcome))
          }
          Future.sequence(pending).map { finished =>
            val collected = ArrayBuffer[java.util.Map[String, Any]]()
            aggregate.subResult = collected
            finished.foreach { case (statistic, outcome) =>
              collected += outcome
              aggregate.pass(statistic.passed)
              aggregate.fail(statistic.failed)
              // AND semantics: a single failing branch fails the whole assertion.
              if (!statistic.isSuccessful) {
                aggregate.isSuccessful = false
                aggregate.msg = AssertResult.MSG_FAILED
              }
            }
            aggregate
          }
        }
      case _ =>
        Future.successful(FailAssertResult(1, AssertResult.msgIncomparableTargetType(expect)))
    }
  }
}
Example 164
Source File: TriggerEventsSaveActor.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.es.actor

import akka.actor.{Props, Status}
import asura.common.actor.BaseActor
import asura.common.util.LogUtils
import asura.core.actor.messages.Flush
import asura.core.es.model.TriggerEventLog
import asura.core.es.service.TriggerEventLogService

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._

/**
 * Actor that buffers incoming `TriggerEventLog` messages and hands them to
 * `TriggerEventLogService.index` in batches.
 *
 * A batch is written when the buffer reaches 20 entries, when a `Flush`
 * message arrives, and when the actor stops.
 */
class TriggerEventsSaveActor extends BaseActor {

  // Events received but not yet handed to the index service.
  val logs = ArrayBuffer[TriggerEventLog]()

  override def receive: Receive = {
    case event: TriggerEventLog =>
      logs += event
      if (logs.length >= 20) insert()
      // Every received event also schedules a Flush two seconds later.
      context.system.scheduler.scheduleOnce(2.seconds) {
        self ! Flush
      }(context.system.dispatcher)
    case Flush =>
      insert()
    case Status.Failure(cause) =>
      log.warning(LogUtils.stackTraceToString(cause))
  }

  override def preStart(): Unit = {}

  override def postStop(): Unit = {
    // Write whatever is still buffered before the actor goes away.
    insert()
    log.debug(s"${self.path} is stopped")
  }

  /** Sends the buffered events to the index service and empties the buffer; no-op when empty. */
  private def insert(): Unit =
    if (logs.nonEmpty) {
      log.debug(s"${logs.length} trigger events is saving...")
      TriggerEventLogService.index(logs)
      logs.clear()
    }
}

/** Companion holding the `Props` factory for [[TriggerEventsSaveActor]]. */
object TriggerEventsSaveActor {

  /** Creates the `Props` used to instantiate a new [[TriggerEventsSaveActor]]. */
  def props(): Props = Props(new TriggerEventsSaveActor())
}
Example 165
Source File: ActivitySaveActor.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.es.actor

import akka.actor.{Props, Status}
import asura.common.actor.BaseActor
import asura.common.util.LogUtils
import asura.core.actor.messages.Flush
import asura.core.es.model.Activity
import asura.core.es.service.ActivityService

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._

/**
 * Actor that buffers incoming `Activity` messages and hands them to
 * `ActivityService.index` in batches.
 *
 * A batch is written when the buffer reaches 20 entries, when a `Flush`
 * message arrives, and when the actor stops.
 */
class ActivitySaveActor extends BaseActor {

  // Activities received but not yet handed to the index service.
  val activities = ArrayBuffer[Activity]()

  override def receive: Receive = {
    case activity: Activity =>
      activities += activity
      if (activities.length >= 20) insert()
      // Every received activity also schedules a Flush two seconds later.
      context.system.scheduler.scheduleOnce(2.seconds) {
        self ! Flush
      }(context.system.dispatcher)
    case Flush =>
      insert()
    case Status.Failure(cause) =>
      log.warning(LogUtils.stackTraceToString(cause))
  }

  override def preStart(): Unit = {}

  override def postStop(): Unit = {
    // Write whatever is still buffered before the actor goes away.
    insert()
    log.debug(s"${self.path} is stopped")
  }

  /** Sends the buffered activities to the index service and empties the buffer; no-op when empty. */
  private def insert(): Unit =
    if (activities.nonEmpty) {
      log.debug(s"${activities.length} activities is saving...")
      ActivityService.index(activities)
      activities.clear()
    }
}

/** Companion holding the `Props` factory for [[ActivitySaveActor]]. */
object ActivitySaveActor {

  /** Creates the `Props` used to instantiate a new [[ActivitySaveActor]]. */
  def props(): Props = Props(new ActivitySaveActor())
}
Example 166
Source File: HttpResponse.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.es.model

import asura.core.http.HttpContentTypes
import io.swagger.models.properties.RefProperty
import io.swagger.models.{ModelImpl, Response, Swagger}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
 * Description of a single HTTP response of an API operation.
 *
 * @param description human readable description of the response
 * @param headers     schemas of the response headers
 * @param contentType content type of the response body
 * @param schema      schema of the response body; may be `null` when none could be resolved
 */
case class HttpResponse(
  description: String,
  headers: Seq[ParameterSchema],
  contentType: String,
  schema: JsonSchema
)

object HttpResponse {

  /**
   * Converts the responses of an OpenAPI (Swagger) operation into [[HttpResponse]] values
   * keyed by the same status-code string.
   *
   * Only `RefProperty` schemas that resolve to a `ModelImpl` definition are translated;
   * any other schema results in a `null` schema. The content type is always
   * `HttpContentTypes.JSON`.
   *
   * @param openApi   the parsed specification, used to resolve schema references
   * @param responses responses of one operation keyed by status code
   * @return an immutable map with one [[HttpResponse]] per entry of `responses`
   */
  def toResponses(openApi: Swagger, responses: mutable.Map[String, Response]): Map[String, HttpResponse] = {
    // The definitions map can be absent (null) on a spec that declares none;
    // wrap it once so an unresolved reference yields a null schema instead of an NPE.
    val definitions = Option(openApi.getDefinitions)
    val responseMap = mutable.Map[String, HttpResponse]()
    for ((code, res) <- responses) {
      val schema: JsonSchema = res.getSchema match {
        case p: RefProperty =>
          definitions.map(_.get(p.getSimpleRef)).orNull match {
            case model: ModelImpl =>
              JsonSchema.toJsonSchema(model)
            case _ =>
              null
          }
        case _ =>
          null
      }
      val headers = ArrayBuffer[ParameterSchema]()
      if (null != res.getHeaders) {
        res.getHeaders.forEach((name, property) => {
          headers += (ParameterSchema(
            name = name,
            description = property.getDescription,
            `type` = SchemaObject.translateOpenApiType(property.getType, property.getFormat)
          ))
        })
      }
      responseMap += (code -> HttpResponse(
        description = res.getDescription,
        headers = headers.toList,
        contentType = HttpContentTypes.JSON,
        schema = schema
      ))
    }
    responseMap.toMap
  }
}
Example 167
Source File: RecommendService.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.es.service

import asura.common.util.StringUtils
import asura.core.concurrent.ExecutionContextManager.sysGlobal
import asura.core.es.model.{FieldKeys, Project}
import asura.core.model.RecommendProject

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future

/** Builds project recommendations from the recent-project aggregation of `ActivityService`. */
object RecommendService {

  /**
   * Loads the user's own recent projects and, when `discover` is set, a second
   * list obtained with `me = false`.
   *
   * @param user     user the recommendations are computed for
   * @param wd       search word applied to the user's own projects
   * @param discover whether the second ("others") list is loaded at all
   * @return both lists wrapped in [[RecommendProjects]]; `others` is empty when `discover` is false
   */
  def getRecommendProjects(user: String, wd: String, discover: Boolean): Future[RecommendProjects] = {
    for {
      my <- getRecommendProject(user, true, wd, 20, Nil)
      other <-
        if (discover) getRecommendProject(user, false, null, 5, my.map(p => (p.group, p.project)))
        else Future.successful(Nil)
    } yield RecommendProjects(my, other)
  }

  /**
   * Turns the aggregation items of `ActivityService.recentProjects` into
   * `RecommendProject`s and fills in each project's summary.
   *
   * Items whose id is empty or not of the form `group/project` are ignored.
   * NOTE(review): `excludeGPs` is accepted but not used by this implementation.
   *
   * @return the recommended projects in aggregation order
   */
  def getRecommendProject(user: String, me: Boolean, wd: String, size: Int, excludeGPs: Seq[(String, String)]): Future[Seq[RecommendProject]] = {
    val items = ArrayBuffer[RecommendProject]()
    ActivityService.recentProjects(user, me, wd, size).flatMap { aggItems =>
      if (aggItems.isEmpty) {
        Future.successful(items)
      } else {
        // Index by document id so summaries can be attached once the projects are fetched.
        val byDocId = mutable.Map[String, RecommendProject]()
        for (item <- aggItems if StringUtils.isNotEmpty(item.id)) {
          val gp = item.id.split("/")
          if (gp.length == 2) {
            val recommended = RecommendProject(gp(0), gp(1), item.count)
            items += recommended
            byDocId += (Project.generateDocId(gp(0), gp(1)) -> recommended)
          }
        }
        ProjectService.getByIds(byDocId.keys.toSeq, Seq(FieldKeys.FIELD_SUMMARY)).map { idMap =>
          idMap.foreach { case (id, project) =>
            byDocId(id).summary = project.summary
          }
          items
        }
      }
    }
  }

  /** Result holder: the user's own projects and the discovered ones. */
  case class RecommendProjects(
    my: Seq[RecommendProject],
    others: Seq[RecommendProject]
  )

}
Example 168
Source File: HomeService.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.es.service

import asura.common.util.StringUtils
import asura.core.concurrent.ExecutionContextManager.sysGlobal
import asura.core.es.EsClient
import asura.core.es.model._
import asura.core.model.QueryHome
import com.sksamuel.elastic4s.ElasticDsl._
import com.sksamuel.elastic4s.requests.searches.queries.Query

import scala.collection.mutable.ArrayBuffer

/** Search used by the home page: a text query across several indices. */
object HomeService extends CommonService {

  /** Source fields returned for every hit. */
  val includeFields = Seq(
    FieldKeys.FIELD_GROUP,
    FieldKeys.FIELD_PROJECT,
    FieldKeys.FIELD_ID,
    FieldKeys.FIELD_AVATAR,
    FieldKeys.FIELD_SUMMARY,
    FieldKeys.FIELD_DESCRIPTION,
    FieldKeys.FIELD_OBJECT_REQUEST_URLPATH
  )

  /**
   * Searches all listed indices and returns at most three hits.
   *
   * When `query.text` is empty no match clause is added to the bool query.
   */
  def queryDoc(query: QueryHome) = {
    EsClient.esClient.execute {
      val esQueries = ArrayBuffer.empty[Query]
      if (StringUtils.isNotEmpty(query.text)) {
        esQueries += matchQuery(FieldKeys.FIELD__TEXT, query.text)
      }
      val filter = boolQuery().must(esQueries)
      search(Group.Index, Project.Index, HttpCaseRequest.Index, DubboRequest.Index, SqlRequest.Index,
        Environment.Index, Scenario.Index, Job.Index)
        .query(filter)
        .sourceInclude(includeFields)
        .size(3)
    }
  }
}
Example 169
Source File: TriggerEventLogService.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.es.service

import asura.common.model.ApiMsg
import asura.common.util.{FutureUtils, StringUtils}
import asura.core.concurrent.ExecutionContextManager.sysGlobal
import asura.core.es.EsClient
import asura.core.es.model._
import asura.core.model.QueryCiEvents
import asura.core.util.JacksonSupport.jacksonJsonIndexable
import com.sksamuel.elastic4s.ElasticDsl._
import com.sksamuel.elastic4s.requests.searches.queries.Query

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future

/** Elasticsearch access for `TriggerEventLog` documents. */
object TriggerEventLogService extends CommonService with BaseAggregationService {

  /**
   * Bulk-indexes the given events.
   *
   * @param items events to index
   * @return the bulk response; when `items` is `null` or empty, the result of
   *         `FutureUtils.illegalArgs(ApiMsg.INVALID_REQUEST_BODY)` instead
   */
  def index(items: Seq[TriggerEventLog]): Future[BulkDocResponse] = {
    // `||`, not `&&`: with `&&` a null argument threw an NPE on `isEmpty`
    // and an empty list was never rejected.
    if (null == items || items.isEmpty) {
      FutureUtils.illegalArgs(ApiMsg.INVALID_REQUEST_BODY)
    } else {
      EsClient.esClient.execute {
        bulk(
          items.map(item => indexInto(TriggerEventLog.Index).doc(item))
        )
      }.map(toBulkDocResponse(_))
    }
  }

  /**
   * Pages through events, newest first (sorted by `FIELD_CREATED_AT` descending).
   *
   * Each non-empty field of `query` (group, project, env, type, service) adds a
   * term filter; all filters must match.
   */
  def queryEvents(query: QueryCiEvents) = {
    val esQueries = ArrayBuffer[Query]()
    if (StringUtils.isNotEmpty(query.group)) esQueries += termQuery(FieldKeys.FIELD_GROUP, query.group)
    if (StringUtils.isNotEmpty(query.project)) esQueries += termQuery(FieldKeys.FIELD_PROJECT, query.project)
    if (StringUtils.isNotEmpty(query.env)) esQueries += termQuery(FieldKeys.FIELD_ENV, query.env)
    if (StringUtils.isNotEmpty(query.`type`)) esQueries += termQuery(FieldKeys.FIELD_TYPE, query.`type`)
    if (StringUtils.isNotEmpty(query.service)) esQueries += termQuery(FieldKeys.FIELD_SERVICE, query.service)
    EsClient.esClient.execute {
      search(TriggerEventLog.Index).query(boolQuery().must(esQueries))
        .from(query.pageFrom)
        .size(query.pageSize)
        .sortByFieldDesc(FieldKeys.FIELD_CREATED_AT)
    }
  }
}
Example 170
Source File: IndexService.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.es.service

import asura.common.util.StringUtils
import asura.core.concurrent.ExecutionContextManager.sysGlobal
import asura.core.es.EsClient
import asura.core.es.model.{FieldKeys, IndexSetting, JobReportDataItem, RestApiOnlineLog}
import com.sksamuel.elastic4s.ElasticDsl._
import com.sksamuel.elastic4s.Indexes
import com.sksamuel.elastic4s.requests.delete.DeleteByQueryRequest
import com.sksamuel.elastic4s.requests.searches.queries.Query
import com.typesafe.scalalogging.Logger

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future

/** Index and index-template management on top of the shared Elasticsearch client. */
object IndexService extends CommonService {

  val logger = Logger("IndexService")

  /**
   * Makes sure the index described by `idx` exists, creating it when missing.
   * Blocks on the Elasticsearch calls.
   *
   * @return `true` when the index exists or was created; `false` (after logging
   *         the error reason) when either request failed
   */
  def initCheck(idx: IndexSetting): Boolean = {
    val cli = EsClient.esClient
    val existsRes = cli.execute(indexExists(idx.Index)).await
    if (!existsRes.isSuccess) {
      logger.error(existsRes.error.reason)
      false
    } else if (existsRes.result.exists) {
      true
    } else {
      val createRes = cli.execute {
        createIndex(idx.Index)
          .shards(idx.shards)
          .replicas(idx.replicas)
          .mapping(idx.mappings)
      }.await
      if (createRes.isSuccess) {
        true
      } else {
        logger.error(createRes.error.reason)
        false
      }
    }
  }

  /** Blocking check of both index templates; the second one is only checked when the first succeeds. */
  def checkTemplate(): Boolean = {
    checkIndexTemplate(JobReportDataItem).await && checkIndexTemplate(RestApiOnlineLog).await
  }

  /**
   * Ensures an index template named after `idxSetting.Index` exists.
   *
   * A 404 status or a failed lookup counts as "missing", in which case the
   * template is created for the pattern `<index>-*`.
   *
   * @return `true` when the template already exists or the creation was acknowledged
   */
  def checkIndexTemplate(idxSetting: IndexSetting): Future[Boolean] = {
    logger.info(s"check es template ${idxSetting.Index}")
    val cli = EsClient.esClient
    val templateExists = cli.execute {
      getIndexTemplate(idxSetting.Index)
    }.map(res => res.status != 404).recover {
      case _ => false
    }
    templateExists.flatMap { exists =>
      if (exists) {
        Future.successful(true)
      } else {
        cli.execute {
          createIndexTemplate(idxSetting.Index, s"${idxSetting.Index}-*")
            .settings(Map(
              "number_of_replicas" -> idxSetting.replicas,
              "number_of_shards" -> idxSetting.shards
            ))
            .mappings(idxSetting.mappings)
        }.map(created => created.result.acknowledged)
      }
    }
  }

  /** Deletes the given indices. */
  def delIndex(indices: Seq[String]) = {
    EsClient.esClient.execute {
      deleteIndex(indices)
    }.map(toDeleteIndexResponse(_))
  }

  /**
   * Deletes, with an immediate refresh, all documents of `indices` matching the
   * given group and/or project; an empty `group` or `project` adds no filter.
   */
  def deleteByGroupOrProject(indices: Seq[String], group: String, project: String) = {
    val esQueries = ArrayBuffer.empty[Query]
    if (StringUtils.isNotEmpty(group)) {
      esQueries += termQuery(FieldKeys.FIELD_GROUP, group)
    }
    if (StringUtils.isNotEmpty(project)) {
      esQueries += termQuery(FieldKeys.FIELD_PROJECT, project)
    }
    EsClient.esClient.execute {
      DeleteByQueryRequest(
        Indexes(indices),
        boolQuery().must(esQueries)
      ).refreshImmediately
    }.map(toDeleteByQueryResponse(_))
  }
}
Example 171
Source File: ScalapropsRunner.scala    From scalaprops   with MIT License 5 votes vote down vote up
package scalaprops

import sbt.testing._
import scala.collection.mutable.ArrayBuffer

/** Helpers shared by the sbt test-interface runner. */
object ScalapropsRunner {

  /** Names of the test fields of `clazz`, as reported by `Scalaprops.testFieldNames`. */
  def testFieldNames(clazz: Class[_]): Array[String] = Scalaprops.testFieldNames(clazz)

  /** Not implemented in this source: calling it throws `NotImplementedError`. */
  private[scalaprops] def getTestObject(
    fingerprint: Fingerprint,
    testClassName: String,
    testClassLoader: ClassLoader
  ): Scalaprops = ???

  /** Not implemented in this source: calling it throws `NotImplementedError`. */
  private[scalaprops] def findTests(
    fingerprint: Fingerprint,
    testClassName: String,
    testClassLoader: ClassLoader,
    only: List[String],
    logger: Logger
  ): Properties[_] = ???
}

/**
 * sbt test-interface `Runner` for scalaprops.
 *
 * All tasks created by this runner share one result buffer, which is
 * formatted and printed by [[done]].
 */
final class ScalapropsRunner(
  override val args: Array[String],
  override val remoteArgs: Array[String],
  testClassLoader: ClassLoader
) extends Runner {

  // Shared with every task created below; each task receives this same buffer.
  private[this] val results = ArrayBuffer.empty[TestResult]
  private[this] val arguments = Arguments.parse(args.toList)

  /** Wraps one task definition into an executable task with a fresh `TestStatus`. */
  private[this] def toTask(taskDef: TaskDef): sbt.testing.Task =
    new ScalapropsTaskImpl(taskDef, testClassLoader, args, arguments, results, TestStatus())

  override def tasks(taskDefs: Array[TaskDef]): Array[sbt.testing.Task] = taskDefs.map(toTask)

  /** Prints the formatted summary of all collected results and returns it. */
  override def done(): String = {
    val summary = TestResult.formatResults(results, arguments.showDuration)
    println(summary)
    summary
  }
}
Example 172
Source File: SolrTableFactory.scala    From solr-sql   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package org.apache.calcite.adapter.solr

import scala.annotation.migration
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer

import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.schema.SchemaPlus
import org.apache.calcite.schema.TableFactory
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.log4j.Logger
import org.apache.solr.client.solrj.SolrClient
import org.apache.solr.client.solrj.impl.CloudSolrClient
import org.apache.solr.client.solrj.impl.HttpSolrClient


/** Supplies the `SolrClient` used to talk to Solr. */
trait SolrClientFactory {
	/** Returns the client to use; implementations decide whether it is created or reused. */
	def getClient(): SolrClient
}

/**
 * Calcite `TableFactory` that builds a `SolrTable` from the operands of a
 * table definition (columns, optional column mapping, connection options).
 */
class SolrTableFactory extends TableFactory[SolrTable] {
	val logger = Logger.getLogger(this.getClass)

	/**
	 * Creates the table.
	 *
	 * The `columns` operand is required. The Solr client is created on first use:
	 * via ZooKeeper when the ZooKeeper-hosts option is present, otherwise from the
	 * server URL option.
	 */
	override def create(parentSchema: SchemaPlus, name: String,
		operands: java.util.Map[String, Object], rowTypw: RelDataType): SolrTable = {

		// All operand values are handled as strings from here on.
		val args = JavaConversions.mapAsScalaMap(operands).toMap.map { case (key, value) => (key, value.toString()) }
		//columns="title string, url string, content_length int"
		SolrTableConf.argumentsRequired(args, SolrTableConf.COULMNS)

		val columns: Map[String, SqlTypeName] = SolrTableConf.parseColumns(args, SolrTableConf.COULMNS)
		logger.debug(s"defined columns: $columns")

		//columnMapping="title->solr_field_title, url->solr_field_url"
		val definedColumnMapping = SolrTableConf.parseMap(args, SolrTableConf.COLUMN_MAPPING)
		logger.debug(s"defined column mapping: $definedColumnMapping")

		// Columns without an explicit mapping map to a Solr field of the same name.
		val filledColumnMapping = columns.map { case (column, _) =>
			(column, definedColumnMapping.getOrElse(column, column))
		}

		//options="pageSize:20,solrZkHosts=10.0.71.14:2181,10.0.71.17:2181,10.0.71.38:2181"
		val options = args

		// Lazily created client, cached in a one-element buffer after the first call.
		val solrClientFactory = new SolrClientFactory {
			val clients = ArrayBuffer[SolrClient]()

			override def getClient(): SolrClient = {
				if (clients.isEmpty) {
					if (options.contains(SolrTableConf.SOLR_ZK_HOSTS)) {
						val solrZkHosts = options(SolrTableConf.SOLR_ZK_HOSTS)
						logger.debug(s"connecting to solr cloud via zookeeper servers: $solrZkHosts")
						val cloudClient = new CloudSolrClient(solrZkHosts)
						cloudClient.setDefaultCollection(options("solrCollection"))
						clients += cloudClient
					} else {
						SolrTableConf.argumentsRequired(args, SolrTableConf.SOLR_ZK_HOSTS, SolrTableConf.SOLR_SERVER_URL)

						val solrServerURL = options(SolrTableConf.SOLR_SERVER_URL)
						logger.debug(s"connecting to solr server: $solrServerURL")
						clients += new HttpSolrClient(solrServerURL)
					}
				}

				clients.head
			}
		}

		new SolrTable(solrClientFactory, columns, filledColumnMapping, options)
	}
}
Example 173
Source File: TSQR.scala    From SparkAndMPIFactorizations   with MIT License 5 votes vote down vote up
package edu.berkeley.cs.amplab.mlmatrix

import java.util.concurrent.ThreadLocalRandom
import scala.collection.mutable.ArrayBuffer

import breeze.linalg._

import edu.berkeley.cs.amplab.mlmatrix.util.QRUtils
import edu.berkeley.cs.amplab.mlmatrix.util.Utils

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.Accumulator
import org.apache.spark.SparkContext._

import java.util.Calendar
import java.text.SimpleDateFormat

/** Helpers for a tall-skinny QR variant: status reporting and the pairwise reduction step. */
class modifiedTSQR extends Serializable {

    /**
     * Prints `message` prefixed with the current time (pattern `H:m:s`).
     * Nothing is printed when `verbose` is false.
     */
    def report(message: String, verbose: Boolean = true) = {
        if (verbose) {
            val timestamp = new SimpleDateFormat("H:m:s").format(Calendar.getInstance().getTime())
            println(s"STATUS REPORT ($timestamp): $message")
        }
    }


  /**
   * Combines two partial results: the vectors are added and the matrices are
   * stacked vertically and passed through `QRUtils.qrR`.
   *
   * The elapsed time of the call, in milliseconds, is added to `acc`.
   */
  private def reduceQR(
      acc: Accumulator[Double],
      a: (DenseVector[Double], DenseMatrix[Double]),
      b: (DenseVector[Double], DenseMatrix[Double])): (DenseVector[Double], DenseMatrix[Double]) = {
    val startNanos = System.nanoTime
    val stacked = DenseMatrix.vertcat(a._2, b._2)
    val reducedR = QRUtils.qrR(stacked, false)
    val summedNorms = a._1 + b._1
    val elapsedMillis = (System.nanoTime - startNanos) / 1e6
    acc += elapsedMillis
    (summedNorms, reducedR)
  }

}
Example 174
Source File: ParallelizedWithLocalityRDD.scala    From cloud-integration   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.cloudera

import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.rdd.{ParallelCollectionPartition, RDD}


  /**
   * Splits `seq` into `numSlices` consecutive sub-sequences of near-equal size.
   *
   * `Range` inputs are sliced into ranges and `NumericRange` inputs into
   * numeric ranges, so the elements are not materialized; any other sequence
   * is copied into an array first.
   *
   * @param seq       the sequence to split
   * @param numSlices number of slices to produce; must be at least 1
   * @return exactly `numSlices` slices, in order; some may be empty
   * @throws IllegalArgumentException when `numSlices` is smaller than 1
   */
  def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    if (numSlices < 1) {
      throw new IllegalArgumentException("Positive number of partitions required")
    }

    // Slice boundaries depend only on (length, numSlices), so sequences of equal
    // length are always cut at the same index positions (needed e.g. for zipping).
    def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
      Iterator.range(0, numSlices).map { i =>
        val start = ((i * length) / numSlices).toInt
        val end = (((i + 1) * length) / numSlices).toInt
        (start, end)
      }

    seq match {
      case r: Range =>
        val lastIndex = numSlices - 1
        val ranges = positions(r.length, numSlices).zipWithIndex.map {
          // The final slice of an inclusive range must itself be inclusive.
          case ((from, _), idx) if r.isInclusive && idx == lastIndex =>
            new Range.Inclusive(r.start + from * r.step, r.end, r.step)
          case ((from, until), _) =>
            new Range(r.start + from * r.step, r.start + until * r.step, r.step)
        }
        ranges.toSeq.asInstanceOf[Seq[Seq[T]]]
      case nr: NumericRange[T] =>
        // Ranges of Long, Double, BigInteger, etc.: peel slices off the front one by one.
        val slices = new ArrayBuffer[Seq[T]](numSlices)
        var remaining = nr
        positions(nr.length, numSlices).foreach { case (start, end) =>
          val sliceSize = end - start
          slices += remaining.take(sliceSize).asInstanceOf[Seq[T]]
          remaining = remaining.drop(sliceSize)
        }
        slices
      case _ =>
        // Copy to an array once so that slicing a List is not quadratic.
        val array = seq.toArray
        val bounds = positions(array.length, numSlices)
        bounds.map { case (start, end) => array.slice(start, end).toSeq }.toSeq
    }
  }
} 
Example 175
Source File: DelTransfer.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.pileup.converters

import scala.collection.mutable.ArrayBuffer

/**
 * A deletion on a contig, described by its start position and length.
 *
 * @param contig contig the deletion is located on
 * @param start  start position of the deletion
 * @param len    length of the deletion
 */
case class DelTransfer(contig: String, start: Int, len: Int) {

  /** End position of the deletion: `start + len`. */
  val endDel: Int = start + len

  /**
   * Tells whether the queried position lies on the same contig and within
   * `(start, endDel]`, i.e. strictly after `start` and up to and including `endDel`.
   */
  def isOverlappingLocus(queryContig: String, queryStart: Int): Boolean =
    queryContig == contig && queryStart > start && queryStart <= endDel
}

/** Collects deletions and answers how many of them cover a given locus. */
class DelContext extends Serializable {
  // Deletions of this length or shorter are ignored by `add`.
  private val minDelLen: Int = 0
  val dels: ArrayBuffer[DelTransfer] = new ArrayBuffer[DelTransfer]()

  /** Records the deletion unless its length is not greater than `minDelLen`. */
  def add(delTransfer: DelTransfer): Unit =
    if (delTransfer.len > minDelLen) dels.append(delTransfer)

  /** Number of recorded deletions that overlap `position` on `contig`. */
  def getDelTransferForLocus(contig: String, position: Int): Int =
    dels.count(_.isOverlappingLocus(contig, position))
}
Example 176
Source File: MDTagParser.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.pileup

import java.io.File

import htsjdk.samtools.reference.IndexedFastaSequenceFile
import htsjdk.samtools.{Cigar, CigarOperator, SAMRecord}
import org.apache.log4j.Logger
import org.apache.spark.sql.SparkSession
import org.biodatageeks.sequila.datasources.BAM.BDGAlignFileReaderWriter
import org.seqdoop.hadoop_bam.BAMBDGInputFormat

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

/**
 * One element of a parsed MD tag.
 *
 * The base `'S'` means "skip `length` positions" (nothing to fix there);
 * a lower-case base marks a deletion, an upper-case base a non-deletion.
 */
case class MDOperator(length: Int, base: Char) {
  def isDeletion: Boolean = Character.isLowerCase(base)
  def isNonDeletion: Boolean = Character.isUpperCase(base)
}
/** Parser turning an MD tag string into a sequence of [[MDOperator]]s. */
object MDTagParser {

  val logger: Logger = Logger.getLogger(this.getClass.getCanonicalName)
  // A run of digits, an optional '^', and an optional run of letters.
  val pattern = "([0-9]+)\\^?([A-Za-z]+)?".r

  /**
   * Parses the MD tag `t`.
   *
   * - A tag consisting only of digits becomes a single skip operator.
   * - `<n><base>` becomes a skip of `n` followed by the upper-cased base.
   * - `<n>^<bases>` becomes a skip of `n` followed by one lower-cased
   *   operator per deleted base.
   * - A token of digits only becomes a skip.
   *
   * @return the operators in tag order
   */
  def parseMDTag(t: String) = {
    if (isAllDigits(t)) {
      Array[MDOperator](MDOperator(t.toInt, 'S'))
    } else {
      val operators = ArrayBuffer.empty[MDOperator]
      pattern.findAllIn(t).foreach { token =>
        val endsWithBase = token.last.isLetter
        if (endsWithBase && !token.contains('^')) {
          operators += MDOperator(token.dropRight(1).toInt, 'S')
          operators += MDOperator(0, token.last.toUpper)
        } else if (endsWithBase) {
          // Deletions are encoded as lower-case bases.
          val parts = token.split('^')
          operators += MDOperator(parts.head.toInt, 'S')
          parts(1).foreach(deleted => operators += MDOperator(0, deleted.toLower))
        } else {
          operators += MDOperator(token.toInt, 'S')
        }
      }
      operators.toArray
    }
  }


  /** True when every character of `s` is a digit (also true for the empty string). */
  private def isAllDigits(s: String): Boolean = s.forall(_.isDigit)

}
Example 177
Source File: NCListBuilder.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.NCList

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/** Builds a nested containment list (`NCList`) from an array of intervals. */
object NCListBuilder {

  /**
   * Builds the list for `array`.
   *
   * Region ids (`rgid`) are indices into `array`. Regions are visited ordered by
   * start ascending and, for equal starts, by end descending (two stable sorts),
   * and each one is appended under the innermost region still on the stack whose
   * end is not smaller than its own.
   *
   * @return the top-level list
   */
  def build[T](array: Array[(Interval[Int], T)]): NCList = {
    val topNCList = NCList(ArrayBuffer.empty[NCList], 0, ArrayBuffer.empty[Int])

    val arrayWithIndices = array.zipWithIndex.map { case (region, idx) => (idx, region) }
    // Secondary key first, primary key second: sortWith is stable.
    val sortedIndices = arrayWithIndices
      .sortWith((x, y) => x._2._1.end > y._2._1.end)
      .sortWith((x, y) => x._2._1.start < y._2._1.start)
      .map(_._1)

    val stack = mutable.ArrayStack[NCListBuildingStack]()

    sortedIndices.foreach { rgid =>
      val currentEnd = array(rgid)._1.end
      // Drop every region on the stack that ends before the current one.
      while (stack.nonEmpty && array(stack.top.rgid)._1.end < currentEnd)
        stack.pop()

      val landingNCList = if (stack.isEmpty) topNCList else stack.top.ncList
      stack.push(appendNCListElt(landingNCList, rgid))
    }

    topNCList
  }

  /**
   * Appends a new, empty child list for region `rgid` to `landingNCList`,
   * records the region id and increments the child counter.
   *
   * @return the stack element pairing the new child list with `rgid`
   */
  def appendNCListElt(landingNCList: NCList, rgid: Int): NCListBuildingStack = {
    val child = NCList(ArrayBuffer.empty[NCList], 0, ArrayBuffer.empty[Int])
    landingNCList.childrenBuf.append(child)
    landingNCList.rgidBuf.append(rgid)
    landingNCList.nChildren = landingNCList.nChildren + 1

    NCListBuildingStack(child, rgid)
  }
}
Example 178
Source File: NCListTree.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.NCList


import scala.collection.mutable.{ArrayBuffer, ArrayStack}
import scala.util.control.Breaks._

class NCListTree[T](allRegions: Array[(Interval[Int], T)]) extends Serializable {

  val ncList = NCListBuilder.build(allRegions)

  def getAllOverlappings(processedInterval: Interval[Int]) = allOverlappingRegions(processedInterval, ncList, allRegions)

  private def allOverlappingRegions(processedInterval: Interval[Int], topNcList: NCList, intervalArray: Array[(Interval[Int],T)]): List[(Interval[Int], T)] = {
    val backpack = Backpack(intervalArray,processedInterval)
    var resultList = List[(Interval[Int], T)]()
    val walkingStack = ArrayStack[NCListWalkingStack]()

    var n = findLandingChild(topNcList, backpack)
    if (n < 0)
      return Nil

    var ncList = moveToChild(topNcList, n, walkingStack)
    while (ncList != null) {
      val stackElt = peekNCListWalkingStackElt(walkingStack)
      val rgid = stackElt.parentNcList.rgidBuf(stackElt.n)
      breakable {
        val candidateInterval = intervalArray(rgid)
        if (candidateInterval._1.start > backpack.processedInterval.end) {
          
    var n = (n1 + n2) / 2
    while (n != n1) {
      b = base(subset(n))._1.end
      if (b == min)
        return n
      if (b < min)
        n1 = n
      else
        n2 = n

      n = (n1 + n2) / 2
    }
    return n2
  }

  /** Records `(parentNcList, n)` on the walking stack and returns the `n`-th child of `parentNcList`. */
  private def moveToChild(parentNcList: NCList, n: Int, walkingStack: ArrayStack[NCListWalkingStack]): NCList = {
    val frame = NCListWalkingStack(parentNcList, n)
    walkingStack.push(frame)
    parentNcList.childrenBuf(n)
  }

  /** Returns the element on top of the walking stack without removing it. */
  private def peekNCListWalkingStackElt(walkingStack: ArrayStack[NCListWalkingStack]): NCListWalkingStack =
    walkingStack.top

  /**
   * Pops the current position and continues with the next sibling or uncle of
   * its parent.
   *
   * @return the next list to visit, or `null` when the stack is exhausted
   */
  private def moveToRightUncle(walkingStack: ArrayStack[NCListWalkingStack]): NCList = {
    val poppedParent = walkingStack.pop().parentNcList
    if (walkingStack.isEmpty) null
    else moveToRightSiblingOrUncle(poppedParent, walkingStack)
  }

  /**
   * Advances the walk to the next list to the right: the right sibling of the
   * current position if there is one, otherwise the right sibling of the
   * closest ancestor that has one.
   *
   * The stack is popped at least once, so it must not be empty on entry.
   * `ncList` does not influence the result; it is kept for signature compatibility.
   *
   * @return the list moved to (its position is left on top of the stack),
   *         or `null` when no ancestor has a right sibling (the stack is then empty)
   */
  private def moveToRightSiblingOrUncle(ncList: NCList, walkingStack: ArrayStack[NCListWalkingStack]): NCList = {
    var found: NCList = null
    var searching = true
    while (searching) {
      val frame = walkingStack.pop()
      val nextIdx = frame.n + 1
      if (nextIdx < frame.parentNcList.nChildren) {
        // A right sibling exists: record the new position and stop.
        walkingStack.push(NCListWalkingStack(frame.parentNcList, nextIdx))
        found = frame.parentNcList.childrenBuf(nextIdx)
        searching = false
      } else {
        // No sibling at this level: climb one level up while frames remain.
        searching = walkingStack.nonEmpty
      }
    }
    found
  }

} 
Example 179
Source File: CoverageUpdate.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.coverage

import org.apache.spark.util.AccumulatorV2

import scala.collection.mutable.ArrayBuffer

/**
 * Coverage data for the right edge of a contig range.
 * NOTE(review): field meanings are inferred from their names — confirm against the producer.
 *
 * @param contig     contig name
 * @param minPos     minimum position
 * @param startPoint start point of the edge
 * @param cov        coverage values
 * @param cumSum     cumulative sum
 */
case class RightCovEdge(
  contig: String,
  minPos: Int,
  startPoint: Int,
  cov: Array[Short],
  cumSum: Short
)

/** Position range `[minPos, maxPos]` on a contig (bounds as named; inclusiveness not shown here). */
case class ContigRange(
  contig: String,
  minPos: Int,
  maxPos: Int
)

/**
 * Mutable holder of right-edge and left-range updates.
 *
 * Both buffers are replaced, never modified in place, by [[reset]] and [[add]].
 */
class CovUpdate(var right: ArrayBuffer[RightCovEdge],
                var left: ArrayBuffer[ContigRange])
    extends Serializable {

  /** Replaces both buffers with new empty ones. */
  def reset(): Unit = {
    right = ArrayBuffer.empty[RightCovEdge]
    left = ArrayBuffer.empty[ContigRange]
  }

  /**
   * Appends the content of `p` to this update by building new buffers.
   *
   * @return this instance, to allow chaining
   */
  def add(p: CovUpdate): CovUpdate = {
    val mergedRight = right ++ p.right
    val mergedLeft = left ++ p.left
    right = mergedRight
    left = mergedLeft
    this
  }

}

/**
 * Spark accumulator collecting [[CovUpdate]]s.
 *
 * @param covAcc the current accumulated value
 */
class CoverageAccumulatorV2(var covAcc: CovUpdate)
    extends AccumulatorV2[CovUpdate, CovUpdate] {

  /** Replaces the accumulated value with a new, empty one. */
  def reset(): Unit = {
    covAcc = new CovUpdate(new ArrayBuffer[RightCovEdge](),
                           new ArrayBuffer[ContigRange]())
  }

  /** Appends `v` to the accumulated value. */
  def add(v: CovUpdate): Unit = {
    covAcc.add(v)
  }

  def value(): CovUpdate = {
    covAcc
  }

  /** True when neither buffer of the accumulated value holds an element. */
  def isZero(): Boolean = {
    covAcc.right.isEmpty && covAcc.left.isEmpty
  }

  /**
   * Creates an independent copy: the copy gets its own `CovUpdate` with cloned
   * buffers, so adding to the copy no longer changes this accumulator
   * (previously both shared the same mutable `CovUpdate`).
   */
  def copy(): CoverageAccumulatorV2 = {
    new CoverageAccumulatorV2(new CovUpdate(covAcc.right.clone(), covAcc.left.clone()))
  }

  /**
   * Creates an empty accumulator directly, which is equivalent to `copy()`
   * followed by `reset()` but avoids cloning buffers that would be discarded.
   */
  override def copyAndReset(): CoverageAccumulatorV2 = {
    new CoverageAccumulatorV2(new CovUpdate(new ArrayBuffer[RightCovEdge](),
                                            new ArrayBuffer[ContigRange]()))
  }

  /** Appends the value of `other` to the accumulated value. */
  def merge(other: AccumulatorV2[CovUpdate, CovUpdate]): Unit = {
    covAcc.add(other.value)
  }
}
Example 180
Source File: BufferBenchmark.scala    From sigmastate-interpreter   with MIT License 5 votes vote down vote up
package special.collections

import debox.Buffer
import spire.syntax.all.cfor
import org.scalameter.api.Bench

import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/**
  * Benchmark cases comparing element-by-element appends into four growable
  * containers: `debox.Buffer`, `mutable.ArrayBuilder`, `ArrayBuffer` and `ListBuffer`.
  *
  * Every case receives an `(arr, n)` pair from `arrays` (presumably a generator
  * declared in `BenchmarkGens` — not visible here) and appends `arr.length`
  * elements using a `cfor` index loop. `n` is not used by any case.
  * The unused `val res` bindings keep the final conversion inside the measured code.
  */
trait BufferBenchmarkCases extends BenchmarkGens { suite: Bench[Double] =>
  // Single shared instance appended repeatedly by the "append[Object]" cases.
  val obj = new Object()
  // Appends the Int elements of the input array, then converts the container.
  performance of "append[Int]" in {
    measure method "of debox.Buffer" in {
      using(arrays) in { case (arr, n) =>
        val buf = Buffer.ofSize[Int](16)
        val limit = arr.length
        cfor(0)(_ < limit, _ + 1) { i =>
          buf.append(arr(i))
        }
        val res = buf.toArray()
      }
    }
    measure method "of ArrayBuilder" in {
      using(arrays) in { case (arr, n) =>
        val buf = mutable.ArrayBuilder.make[Int]()
        val limit = arr.length
        cfor(0)(_ < limit, _ + 1) { i =>
          buf += (arr(i))
        }
        val res = buf.result()
      }
    }
    measure method "of ArrayBuffer" in {
      using(arrays) in { case (arr, n) =>
        val buf = ArrayBuffer.empty[Int]
        val limit = arr.length
        cfor(0)(_ < limit, _ + 1) { i =>
          buf.append(arr(i))
        }
        val res = buf.toArray
      }
    }
    measure method "of ListBuffer" in {
      using(arrays) in { case (arr, n) =>
        val buf = ListBuffer.empty[Int]
        val limit = arr.length
        cfor(0)(_ < limit, _ + 1) { i =>
          buf.append(arr(i))
        }
        val res = buf.toList
      }
    }
  }

  // Appends the shared `obj` reference arr.length times; the array contents are not read.
  // Note: the debox.Buffer and ArrayBuffer cases below do not convert the container
  // afterwards, unlike the ArrayBuilder and ListBuffer cases.
  performance of "append[Object]" in {
    measure method "of debox.Buffer" in {
      using(arrays) in { case (arr, n) =>
        val buf = Buffer.ofSize[Object](100)
        val limit = arr.length
        cfor(0)(_ < limit, _ + 1) { i =>
          buf.append(obj)
        }
      }
    }
    measure method "of ArrayBuilder" in {
      using(arrays) in { case (arr, n) =>
        val buf = mutable.ArrayBuilder.make[Object]()
        val limit = arr.length
        cfor(0)(_ < limit, _ + 1) { i =>
          buf += (obj)
        }
        val res = buf.result()
      }
    }
    measure method "of ArrayBuffer" in {
      using(arrays) in { case (arr, n) =>
        val buf = ArrayBuffer.empty[Object]
        val limit = arr.length
        cfor(0)(_ < limit, _ + 1) { i =>
          buf.append(obj)
        }
      }
    }
    measure method "of ListBuffer" in {
      using(arrays) in { case (arr, n) =>
        val buf = ListBuffer.empty[Object]
        val limit = arr.length
        cfor(0)(_ < limit, _ + 1) { i =>
          buf.append(obj)
        }
        val res = buf.toList
      }
    }
  }
}

/** Runnable benchmark: the cases of [[BufferBenchmarkCases]] on top of `Bench.LocalTime`. */
object FastBufferBenchmark
    extends Bench.LocalTime
    with BufferBenchmarkCases
Example 181
Source File: HBaseLocalClient.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.hbase.utilities

import java.io.File

import scala.collection.mutable.ArrayBuffer

import com.google.common.io.Files
import org.apache.hadoop.hbase.{HBaseTestingUtility, TableName}
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.hbase.SparkHBaseConf
import org.apache.spark.sql.util._
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

import com.paypal.gimel.common.catalog.Field
import com.paypal.gimel.hbase.DataSet

/**
  * Test fixture that starts an HBase mini cluster and a local Spark session
  * before the suite runs and shuts both down afterwards.
  *
  * Successful query executions are recorded in `metrics` through a
  * `QueryExecutionListener` registered in `beforeAll`.
  */
class HBaseLocalClient extends FunSuite with Matchers with BeforeAndAfterAll {

  var sparkSession : SparkSession = _
  var dataSet: DataSet = _
  val hbaseTestingUtility = new HBaseTestingUtility()
  val tableName = "test_table"
  val cfs = Array("personal", "professional")
  val columns = Array("id", "name", "age", "address", "company", "designation", "salary")
  val fields = columns.map(col => new Field(col))

  // (function name, query execution, duration) of every successful query.
  val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]

  /** Starts the mini cluster, creates the test table and builds the Spark session. */
  protected override def beforeAll(): Unit = {
    // The directory is created as a side effect and removed at JVM exit;
    // nothing in this class reads it afterwards.
    val tempDir: File = Files.createTempDir
    tempDir.deleteOnExit()
    hbaseTestingUtility.startMiniCluster()
    SparkHBaseConf.conf = hbaseTestingUtility.getConfiguration
    createTable(tableName, cfs)
    val conf = new SparkConf
    conf.set(SparkHBaseConf.testConf, "true")
    sparkSession = SparkSession.builder()
      .master("local")
      .appName("HBase Test")
      .config(conf)
      .getOrCreate()

    val listener = new QueryExecutionListener {
      // Only test successful case here, so no need to implement `onFailure`
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
        metrics += ((funcName, qe, duration))
      }
    }
    sparkSession.listenerManager.register(listener)
    sparkSession.sparkContext.setLogLevel("ERROR")
    dataSet = new DataSet(sparkSession)
  }

  /**
    * Shuts the mini cluster down and closes the Spark session.
    *
    * The session is closed in a `finally` so it is not leaked when the cluster
    * shutdown throws, and it is skipped when `beforeAll` failed before creating it.
    */
  protected override def afterAll(): Unit = {
    try {
      hbaseTestingUtility.shutdownMiniCluster()
    } finally {
      Option(sparkSession).foreach(_.close())
    }
  }

  /**
    * (Re)creates a multi-region table with the given column families.
    *
    * @param name table name
    * @param cfs  column family names
    */
  def createTable(name: String, cfs: Array[String]): Unit = {
    import scala.util.control.NonFatal

    val tName = Bytes.toBytes(name)
    val bcfs = cfs.map(Bytes.toBytes(_))
    try {
      hbaseTestingUtility.deleteTable(TableName.valueOf(tName))
    } catch {
      // Best effort: the table usually does not exist yet. Fatal errors
      // (e.g. OutOfMemoryError) are no longer swallowed.
      case NonFatal(_) =>
        println("No table = " + name + " found")
    }
    hbaseTestingUtility.createMultiRegionTable(TableName.valueOf(tName), bcfs)
  }

  /**
    * Builds a DataFrame of `numberOfRows` generated JSON records (ids 1 to `numberOfRows`).
    *
    * @param numberOfRows number of rows to generate
    */
  def mockDataInDataFrame(numberOfRows: Int): DataFrame = {
    def stringed(n: Int) = s"""{"id": "$n","name": "MAC-$n", "address": "MAC-${n + 1}", "age": "${n + 1}", "company": "MAC-$n", "designation": "MAC-$n", "salary": "${n * 10000}" }"""
    val texts: Seq[String] = (1 to numberOfRows).map { x => stringed(x) }
    val rdd: RDD[String] = sparkSession.sparkContext.parallelize(texts)
    sparkSession.read.json(rdd)
  }
}
Example 182
Source File: FriendEntity.scala    From lagom-scala-chirper   with Apache License 2.0 5 votes vote down vote up
package sample.chirper.friend.impl

import akka.Done
import com.lightbend.lagom.scaladsl.persistence.PersistentEntity
import com.lightbend.lagom.scaladsl.playjson.{JsonSerializer, JsonSerializerRegistry}
import sample.chirper.friend.api.User

import scala.collection.mutable.ArrayBuffer

/**
  * Persistent entity holding a user and its friends.
  *
  * The entity has two behaviors, chosen from the current state:
  * `notInitialized` while `FriendState.user` is `None`, `initialized` afterwards.
  */
class FriendEntity extends PersistentEntity {

  override type Command = FriendCommand[_]
  override type Event = FriendEvent
  override type State = FriendState

  // The entity starts without a user.
  override def initialState = FriendState(None)

  // Selects the behavior depending on whether a user is present in the state.
  override def behavior = {
    case FriendState(None) => notInitialized
    case FriendState(Some(user)) => initialized
  }

  // Read-only handler shared by both behaviors: replies with the current user (possibly None).
  val onGetUser = Actions().onReadOnlyCommand[GetUser, GetUserReply] {
    case (GetUser(), ctx, state) => ctx.reply(GetUserReply(state.user))
  }

  // Event handler shared by both behaviors: records the new friend in the state.
  val onFriendAdded = Actions().onEvent {
    case (FriendAdded(userId, friendId, timestamp), state) => state.addFriend(friendId)
  }

  // Behavior before the user exists: CreateUser is accepted, AddFriend is rejected.
  val notInitialized = {
    Actions().
      onCommand[CreateUser, Done] {
      case (CreateUser(user), ctx, state) =>
        // Persist one UserCreated event followed by one FriendAdded event per initial friend.
        val events = ArrayBuffer.empty[FriendEvent]
        events += UserCreated(user.userId, user.name)
        events ++= user.friends.map(friendId => FriendAdded(user.userId, friendId))
        ctx.thenPersistAll(events: _*) { () =>
          ctx.reply(Done)
        }
    }.
      onCommand[AddFriend, Done] {
      case (AddFriend(friendUserId), ctx, state) =>
        ctx.invalidCommand(s"User $entityId is not created")
        ctx.done
    }.
      onEvent {
        case (UserCreated(userId, name, timestamp), state) => FriendState(User(userId, name))
      }
  }.orElse(onGetUser).orElse(onFriendAdded)

  // Behavior once the user exists: CreateUser is rejected, AddFriend is accepted.
  val initialized = {
    Actions().
      onCommand[CreateUser, Done] {
      case (CreateUser(user), ctx, state) =>
        ctx.invalidCommand(s"User ${user.name} is already created")
        ctx.done
    }.
      onCommand[AddFriend, Done] {
      // Already a friend: acknowledge without persisting anything.
      case (AddFriend(friendUserId), ctx, state) if state.user.get.friends.contains(friendUserId) =>
        ctx.reply(Done)
        ctx.done
      // `state.user.get` relies on this behavior only being selected for FriendState(Some(_)).
      case (AddFriend(friendUserId), ctx, state) =>
        val event = FriendAdded(state.user.get.userId, friendUserId)
        ctx.thenPersist(event) { _ =>
          ctx.reply(Done)
        }
    }
  }.orElse(onGetUser).orElse(onFriendAdded)

}

/** JSON serializer registry listing the serializers for the friend service's commands, replies, events and state. */
object FriendSerializerRegistry extends JsonSerializerRegistry {
  // One serializer per listed type.
  override def serializers = List(
    JsonSerializer[GetUser],
    JsonSerializer[GetUserReply],
    JsonSerializer[FriendState],
    JsonSerializer[CreateUser],
    JsonSerializer[UserCreated],
    JsonSerializer[AddFriend],
    JsonSerializer[FriendAdded]
  )
} 
Example 183
Source File: TestContext.scala    From swave   with Mozilla Public License 2.0 5 votes vote down vote up
package swave.core.internal.testkit

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import org.scalacheck.rng.Seed
import swave.core.macros._
import swave.core.impl.util.ResizableRingBuffer
import swave.core.util._

/**
  * Per-run context that decides, for each piece of stage work, whether it runs
  * synchronously or is queued and executed later by [[processSchedulings]].
  *
  * @param runNr           number of the test run
  * @param asyncRate       value passed to `random.decide` to pick between sync and queued execution
  * @param asyncScheduling order in which queued tasks are drained
  * @param genSeed         seed; its first `long` component seeds `random`
  * @param tracing         when true, `trace` prints its messages
  */
private[testkit] final class TestContext(val runNr: Int,
                                         val asyncRate: Double,
                                         val asyncScheduling: TestGeneration.AsyncScheduling,
                                         val genSeed: Seed,
                                         tracing: Boolean) {

  import TestContext._

  // One task queue per id handed out by `nextId()`; the id is the index into this buffer.
  private[this] val schedulings = ArrayBuffer.empty[ResizableRingBuffer[Task]]
  val random                    = XorShiftRandom(genSeed.long._1)

  /** The most recently allocated id (-1 when none has been allocated yet). */
  def lastId = schedulings.size - 1

  /** Allocates a new task queue and returns its id. */
  def nextId(): Int = {
    schedulings += new ResizableRingBuffer[Task](16, 4096)
    schedulings.size - 1
  }

  /** Prints `msg` prefixed with the stage when tracing is enabled; `msg` is only evaluated then. */
  def trace(msg: ⇒ String)(implicit stage: TestStage): Unit =
    if (tracing) println(stage.toString + ": " + msg)

  /**
    * Runs `block` now, or queues it on the stage's task queue.
    *
    * The block is queued when the stage already has queued tasks (which keeps
    * the stage's work in order) or when `random.decide(asyncRate)` says so.
    */
  def run(msg: ⇒ String)(block: ⇒ Unit)(implicit stage: TestStage): Unit = {
    val scheduled = schedulings(stage.id)
    if (scheduled.nonEmpty || random.decide(asyncRate)) {
      trace("(scheduling) " + msg)
      // `write` reports whether the task was accepted; a rejection fails the state requirement.
      requireState(scheduled.write(new Task(stage, msg _, block _)))
    } else {
      trace("(sync)       " + msg)
      block
    }
  }

  /** True when at least one queue holds a pending task. */
  def hasSchedulings: Boolean = schedulings.exists(_.nonEmpty)

  /**
    * Drains all queued tasks, repeating until no queue holds a task
    * (running tasks may queue further ones).
    */
  @tailrec def processSchedulings(): Unit =
    if (hasSchedulings) {
      // Work on a snapshot of the queue list; queues added meanwhile are picked up in the next round.
      val snapshot: Array[ResizableRingBuffer[Task]] = schedulings.toArray

      // Runs, per queue, exactly the tasks present when that queue is visited.
      def runSnapshots() = snapshot foreach { buf ⇒
        runTasks(buf, buf.count)
      }

      // Reads and executes `count` tasks from `buf`.
      @tailrec def runTasks(buf: ResizableRingBuffer[Task], count: Int): Unit =
        if (count > 0) {
          val task = buf.read()
          trace("(running)    " + task.msg())(task.stage)
          task.block()
          runTasks(buf, count - 1)
        }

      asyncScheduling match {
        case TestGeneration.AsyncScheduling.InOrder ⇒
          runSnapshots()

        case TestGeneration.AsyncScheduling.RandomOrder ⇒
          random.shuffle_!(snapshot)
          runSnapshots()

        case TestGeneration.AsyncScheduling.ReversedOrder ⇒
          snapshot.reverse_!()
          runSnapshots()

        case TestGeneration.AsyncScheduling.Mixed ⇒
          // Repeatedly shuffles the remaining queues and runs a random number of
          // tasks from each, dropping queues once they are empty.
          @tailrec def rec(remaining: Array[ResizableRingBuffer[Task]]): Unit =
            if (remaining.nonEmpty) {
              random.shuffle_!(remaining)
              rec(remaining flatMap { buf ⇒
                val jobsSize = buf.count
                // NOTE(review): assuming `nextInt(bound)` excludes `bound` (as java.util.Random does),
                // this runs between zero and all of the queued tasks, not "at least one" — confirm.
                runTasks(buf, random.nextInt(jobsSize + 1))
                if (buf.nonEmpty) buf :: Nil else Nil
              })
            }
          rec(snapshot)
      }
      processSchedulings()
    }
}

private[testkit] object TestContext {

  /**
    * A unit of deferred work.
    *
    * @param stage the stage the work belongs to
    * @param msg   produces the trace message for the work
    * @param block the work itself
    */
  private class Task(
      val stage: TestStage,
      val msg: () ⇒ String,
      val block: () ⇒ Unit
  )
}
Example 184
Source File: ercesiMIPSRunner.scala    From ercesiMIPS   with GNU General Public License v3.0 5 votes vote down vote up
// See LICENSE.txt for license details.

package utils

import scala.collection.mutable.ArrayBuffer
import scala.util.Properties.envOrElse

object ercesiMIPSRunner {

  /**
    * Runs the selected tests and reports the outcome on standard output.
    *
    * The backend name passed to every test is the first whitespace-separated
    * token of the `TESTER_BACKENDS` environment variable. When the variable is
    * unset, the backend is `"firrtl"` if the FIRRTL interpreter backend class
    * can be loaded and `""` otherwise. A blank variable yields `""`.
    *
    * Exits the JVM with status 1 when at least one test fails, throws, or is unknown.
    *
    * @param ercesiMIPSMap test name to test function; the function receives the
    *                      backend name and returns true on success
    * @param args          names of the tests to run; empty or starting with
    *                      `"all"` runs every test in name order
    */
  def apply(ercesiMIPSMap: Map[String, String => Boolean], args: Array[String]): Unit = {
    import scala.util.control.NonFatal

    // Choose the default backend based on what is available.
    lazy val firrtlTerpBackendAvailable: Boolean =
      try {
        Class.forName("chisel3.iotesters.FirrtlTerpBackend") != null
      } catch {
        // Missing class or a class that fails to link/initialise: backend not available.
        case NonFatal(_) | _: LinkageError => false
      }
    lazy val defaultBackend = if (firrtlTerpBackendAvailable) "firrtl" else ""

    // `split(" ").head` threw on a whitespace-only value and returned "" for a
    // value with leading whitespace; take the first non-empty token instead.
    val backendName = envOrElse("TESTER_BACKENDS", defaultBackend)
      .split("\\s+")
      .find(_.nonEmpty)
      .getOrElse("")

    val problemsToRun =
      if (args.isEmpty || args.head == "all") ercesiMIPSMap.keys.toSeq.sorted.toArray
      else args

    var successful = 0
    val errors = new ArrayBuffer[String]
    for (testName <- problemsToRun) {
      ercesiMIPSMap.get(testName) match {
        case Some(test) =>
          println(s"Starting ercesiMIPS $testName")
          try {
            if (test(backendName)) {
              successful += 1
            } else {
              errors += s"ercesiMIPS $testName: test error occurred"
            }
          } catch {
            case exception: Exception =>
              exception.printStackTrace()
              errors += s"ercesiMIPS $testName: exception ${exception.getMessage}"
            // Deliberately broad: one failing test (including an Error such as a
            // failed assertion) must not prevent the remaining tests from running.
            case t: Throwable =>
              errors += s"ercesiMIPS $testName: throwable ${t.getMessage}"
          }
        case _ =>
          errors += s"Bad ercesiMIPS name: $testName"
      }
    }

    if (successful > 0) {
      println(s"ercesiMIPSs passing: $successful")
    }
    if (errors.nonEmpty) {
      println("=" * 80)
      println(s"Errors: ${errors.length}: in the following commands")
      println(errors.mkString("\n"))
      println("=" * 80)
      System.exit(1)
    }
  }
}
Example 185
Source File: GzetPersons.scala    From Mastering-Spark-for-Data-Science   with MIT License 5 votes vote down vote up
package io.gzet.community.util

import scala.collection.mutable.ArrayBuffer

object GzetPersons {

  /**
    * Returns every unordered pair of elements of `array`, by position.
    *
    * Pairs are produced in index order: (0,1), (0,2), ..., (1,2), ...
    * Equal elements at different positions are paired like any others.
    * Fewer than two elements yield an empty result.
    *
    * The pairs are built directly from the elements. Previously they were
    * round-tripped through comma-joined strings, which corrupted elements
    * containing a comma and threw when the second element was empty.
    *
    * @param array the elements to pair
    * @return all pairs `(array(i), array(j))` with `i < j`
    */
  def buildTuples(array: Array[String]): Array[(String, String)] = {
    val n = array.length
    val pairs = for {
      i <- 0 until n
      j <- (i + 1) until n
    } yield (array(i), array(j))
    pairs.toArray
  }

  /**
    * Appends to `result` every `r`-element combination of `input(start..end)`,
    * each rendered as its elements joined by `","`.
    *
    * @param input  source elements
    * @param result buffer receiving the comma-joined combinations
    * @param data   scratch array of at least `r` slots holding the combination being built
    * @param start  first index of `input` that may still be chosen
    * @param end    last index of `input` that may be chosen (inclusive)
    * @param index  number of slots of `data` already filled
    * @param r      size of each combination
    */
  def combinations(input: Array[String], result: ArrayBuffer[String], data: Array[String], start: Int, end: Int, index: Int, r: Int): Unit = {
    if (index == r) {
      // All slots filled: emit the combination.
      result += data.take(r).mkString(",")
    } else {
      var j = start
      // Stop as soon as too few elements remain to fill the open slots.
      while (j <= end && (end - j + 1) >= (r - index)) {
        data(index) = input(j)
        combinations(input, result, data, j + 1, end, index + 1, r)
        j += 1
      }
    }
  }
}
Example 186
Source File: IncrementalSeq.scala    From inox   with Apache License 2.0 5 votes vote down vote up
package inox.utils

import scala.collection.mutable.Builder
import scala.collection.mutable.ArrayBuffer
import scala.collection.{Iterable, IterableLike}

/**
  * Sequence with push/pop snapshots, backed by a stack of buffers.
  *
  * All reads and writes go to the buffer on top of the stack. `push` puts a
  * copy of the top buffer on the stack, `pop` discards the top buffer.
  */
class IncrementalSeq[A] extends IncrementalState
                        with Iterable[A]
                        with IterableLike[A, Seq[A]]
                        with Builder[A, IncrementalSeq[A]] {

  // Head of the list is the buffer currently in use.
  private[this] var stack: List[ArrayBuffer[A]] = new ArrayBuffer[A]() :: Nil

  /** Drops every level and starts again with a single empty buffer. */
  def clear() : Unit = {
    stack = new ArrayBuffer[A]() :: Nil
  }

  /** Same as [[clear]]. */
  def reset(): Unit = clear()

  /** Opens a new level holding a copy of the current elements. */
  def push(): Unit = {
    stack = stack.head.clone :: stack
  }

  /** Discards the current level, returning to the previous one. */
  def pop(): Unit = {
    stack = stack.tail
  }

  /** Iterates over a snapshot of the current level. */
  def iterator = {
    val snapshot = stack.head.toList
    snapshot.iterator
  }

  /** Adds `e` to the current level. */
  def +=(e: A) = { stack.head += e; this }

  /** Removes `e` from the current level. */
  def -=(e: A) = { stack.head -= e; this }

  override def newBuilder = new scala.collection.mutable.ArrayBuffer()
  def result = this
}
Example 187
Source File: MatchCollector.scala    From piglet   with Apache License 2.0 5 votes vote down vote up
package dbis.piglet.cep.ops
import dbis.piglet.cep.nfa.NFAStructure
import scala.reflect.ClassTag
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.ArrayBuffer
import dbis.piglet.backends.{SchemaClass => Event}

/** Collects matched NFA structures and exposes their events. */
class MatchCollector[ T <: Event: ClassTag] extends Serializable {
  var macthSequences: ListBuffer[NFAStructure[T]] = ListBuffer.empty[NFAStructure[T]]

  /** Records one more matched structure. */
  def +(that: NFAStructure[T]): Unit = {
    macthSequences += that
  }

  /** Number of matched structures recorded so far. */
  def size: Int = macthSequences.length

  /** Returns the events of all recorded structures, in recording order, as one buffer. */
  def convertEventsToArray(): ArrayBuffer[T] = {
    val collected = new ArrayBuffer[T]()
    for (sequence <- macthSequences) {
      collected ++= sequence.events
    }
    collected
  }

  /** Returns a one-element buffer telling whether any structure has been recorded. */
  def convertEventsToBoolean(): ArrayBuffer[Boolean] =
    ArrayBuffer(macthSequences.nonEmpty)
}
Example 188
Source File: NFAStructure.scala    From piglet   with Apache License 2.0 5 votes vote down vote up
package dbis.piglet.cep.nfa
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.collection.mutable.HashMap
import dbis.piglet.backends.{SchemaClass => Event}
import scala.collection.mutable.ListBuffer


  /**
    * Records `event` and moves this structure along `currentEdge`.
    *
    * The structure is marked complete when the edge leads to a final state.
    * (Updating related values for the edge is currently disabled.)
    *
    * @param event       the event that triggered the transition
    * @param currentEdge the edge being taken
    */
  def addEvent(event: T, currentEdge: ForwardEdge[T]): Unit = {
    events += event
    currenState = currentEdge.destState
    currenState match {
      case _: FinalState[_] => complete = true
      case _                => ()
    }
  }
  
  /**
    * Creates a copy of this structure for the same controller.
    *
    * The completion flag and current state are carried over; the event buffer
    * is cloned so the copy can grow independently of this instance.
    */
  override def clone(): NFAStructure[T] = {
    val duplicate = new NFAStructure[T](nfaController)
    duplicate.complete = complete
    duplicate.currenState = currenState
    duplicate.events = events.clone()
    duplicate
  }
} 
Example 189
Source File: FlinkStreamingCEPTest.scala    From piglet   with Apache License 2.0 5 votes vote down vote up
package dbis.cep.test.flink

import java.io.File

import dbis.piglet.backends.{ Record, SchemaClass }
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.scalatest._
import org.apache.commons.io.FileUtils
import org.apache.flink.api.scala._
import dbis.piglet.cep.nfa._
import dbis.piglet.cep.ops.SelectionStrategy._
import dbis.piglet.cep.ops.OutputStrategy._
import dbis.piglet.cep.flink.CustomDataStreamMatcher._
import scala.collection.mutable.ArrayBuffer
import org.apache.flink.streaming.api.windowing.windows.GlobalWindow
import org.apache.flink.streaming.api.windowing.assigners.GlobalWindows

/** Two-column integer record used as the event type in the streaming CEP tests. */
case class StreamingDoubleRecord(col1: Int, col2: Int) extends java.io.Serializable with SchemaClass {

  /** Renders both columns separated by `delim`. */
  override def mkString(delim: String) = Seq(col1, col2).mkString(delim)
}

/** Builds the NFA used by the tests: three edges accepting `col1` values 1, 2 and 3 in sequence. */
object OurStreamingNFA {

  // Edge predicates: each accepts a record by its first column only.
  def filter1(record: StreamingDoubleRecord, rvalues: NFAStructure[StreamingDoubleRecord]): Boolean =
    record.col1 == 1
  def filter2(record: StreamingDoubleRecord, rvalues: NFAStructure[StreamingDoubleRecord]): Boolean =
    record.col1 == 2
  def filter3(record: StreamingDoubleRecord, rvalues: NFAStructure[StreamingDoubleRecord]): Boolean =
    record.col1 == 3

  /** Creates a fresh controller with the states First -> Second -> Third -> Final. */
  def createNFA = {
    val nfa: NFAController[StreamingDoubleRecord] = new NFAController()

    // States are created before edges, edges before transitions.
    val start = nfa.createAndGetStartState("First")
    val second = nfa.createAndGetNormalState("Second")
    val third = nfa.createAndGetNormalState("Third")
    val end = nfa.createAndGetFinalState("Final")

    val edgeOne = nfa.createAndGetForwardEdge(filter1)
    val edgeTwo = nfa.createAndGetForwardEdge(filter2)
    val edgeThree = nfa.createAndGetForwardEdge(filter3)

    nfa.createForwardTransition(start, edgeOne, second)
    nfa.createForwardTransition(second, edgeTwo, third)
    nfa.createForwardTransition(third, edgeThree, end)
    nfa
  }
}

/**
  * Builds the SEQ(A, B, C) matcher on a Flink stream once per selection strategy.
  *
  * The tests only construct the pipeline; they neither execute it nor assert on a result.
  */
class FlinkStreamingCEPTest extends FlatSpec with Matchers with BeforeAndAfterEach {
  var resultArray = new ArrayBuffer[StreamingDoubleRecord]

  override def beforeEach(): Unit = resultArray.clear()

  val sample = Seq(
    StreamingDoubleRecord(1, 1),
    StreamingDoubleRecord(2, 2),
    StreamingDoubleRecord(1, 3),
    StreamingDoubleRecord(2, 4),
    StreamingDoubleRecord(3, 5),
    StreamingDoubleRecord(1, 6),
    StreamingDoubleRecord(2, 7),
    StreamingDoubleRecord(3, 8))

  // Execution environment with sysout logging switched off.
  private def quietEnvironment(): StreamExecutionEnvironment = {
    val environment = StreamExecutionEnvironment.getExecutionEnvironment
    environment.getConfig.disableSysoutLogging()
    environment
  }

  "Flink Streaming CEP" should "detect the pattern SEQ(A, B, C) with first match" in {
    val environment = quietEnvironment()
    val records = environment.fromCollection(sample)
    val matched = records.matchNFA(OurStreamingNFA.createNFA, environment, FirstMatch)
  }

  it should "detect the pattern SEQ(A, B, C) with any match" in {
    val environment = quietEnvironment()
    val records = environment.fromCollection(sample)
    val matched = records.matchNFA(OurStreamingNFA.createNFA, environment, AllMatches)
  }

  it should "detect the pattern SEQ(A, B, C) with next match" in {
    val environment = quietEnvironment()
    val records = environment.fromCollection(sample)
    val matched = records.matchNFA(OurStreamingNFA.createNFA, environment, NextMatches)
  }

  it should "detect the pattern SEQ(A, B, C) with contiguity match" in {
    val environment = quietEnvironment()
    val records = environment.fromCollection(sample)
    val matched = records.matchNFA(OurStreamingNFA.createNFA, environment, ContiguityMatches)
  }
}
Example 190
Source File: Cross.scala    From piglet   with Apache License 2.0 5 votes vote down vote up
package dbis.piglet.op

import dbis.piglet.schema._

import scala.collection.mutable.ArrayBuffer


/**
  * CROSS operator: combines the fields of all inputs into one output schema.
  *
  * @param out        output pipe
  * @param in         input pipes
  * @param timeWindow optional window given as (size, unit); `null` when absent
  */
case class Cross(
    private val out: Pipe, 
    private val in: List[Pipe], 
    timeWindow: (Int, String)= null.asInstanceOf[(Int, String)]
  ) extends PigOperator(List(out), in) {

//  require(in.size == 2, "Only two inputs allowed for CROSS, currently!")
  
  override def lineageString: String = {
    s"""CROSS%""" + super.lineageString
  }

  /**
    * Builds the output schema as the concatenation of all input fields, each
    * with the name of its input pipe prepended to the field's lineage.
    *
    * @throws NotImplementedError when an input has no schema; that case is not
    *                             supported (previously a bare `???`, now with
    *                             a message naming the offending pipe)
    */
  override def constructSchema: Option[Schema] = {
    val newFields = ArrayBuffer[Field]()
    inputs.foreach(p => p.producer.schema match {
      case Some(s) => newFields ++= s.fields map { f =>
        Field(f.name, f.fType, p.name :: f.lineage)
      }
      case None =>
        throw new NotImplementedError(
          s"CROSS cannot construct a schema: input pipe ${p.name} has no schema")
    })
    schema = Some(Schema(BagType(TupleType(newFields.toArray))))
    schema
  }

  override def toString =
    s"""CROSS
       |  out = ${outPipeNames.mkString(",")}
       |  in = ${inPipeNames.mkString(",")}""".stripMargin

}
Example 191
Source File: Zip.scala    From piglet   with Apache License 2.0 5 votes vote down vote up
package dbis.piglet.op

import dbis.piglet.schema._

import scala.collection.mutable.ArrayBuffer


/**
  * ZIP operator: either zips two or more inputs, or zips a single input with an index.
  *
  * @param out       output pipe
  * @param in        input pipes
  * @param withIndex when true, a `Long` field named "index" is appended to the schema
  */
case class Zip(
    private val out: Pipe, 
    private val in: List[Pipe],
    withIndex: Boolean
  ) extends PigOperator(List(out), in) {

  // Exactly one input with an index, more than one input without.
  require(
    if (withIndex) in.size == 1 else in.size > 1,
    "zip with index works only with one input. Otherwise we must have at least two inputs")

  override def lineageString: String = s"ZIP%$withIndex${super.lineageString}"

  /**
    * Builds the output schema from the fields of all inputs, each with the name
    * of its input pipe prepended to the field's lineage.
    *
    * @throws UnsupportedOperationException when an input has no schema
    */
  override def constructSchema: Option[Schema] = {
    val inputFields = inputs.flatMap { pipe =>
      pipe.producer.schema match {
        case Some(known) =>
          known.fields.map(f => Field(f.name, f.fType, pipe.name :: f.lineage))
        case None =>
          throw new UnsupportedOperationException(s"Cannot zip with unknown Schema! (input pipe $pipe)")
      }
    }

    val allFields =
      if (withIndex) inputFields :+ Field("index", Types.LongType)
      else inputFields

    schema = Some(Schema(BagType(TupleType(allFields.toArray))))
    schema
  }

  override def toString =
    s"""ZIP
       |  out = ${outPipeNames.mkString(",")}
       |  in = ${inPipeNames.mkString(",")}
       |  withIndex = $withIndex""".stripMargin

}
Example 192
Source File: SpatialJoin.scala    From piglet   with Apache License 2.0 5 votes vote down vote up
package dbis.piglet.op

import dbis.piglet.expr.SpatialJoinPredicate
import dbis.piglet.op.IndexMethod.IndexMethod
import dbis.piglet.op.PartitionMethod.PartitionMethod
import dbis.piglet.schema._

import scala.collection.mutable.ArrayBuffer



/**
  * SPATIALJOIN operator.
  *
  * @param out        output pipe
  * @param in         input pipes
  * @param predicate  the spatial join predicate
  * @param index      optional index method with its parameters
  * @param leftParti  optional partition method for the left input
  * @param rightParti optional partition method for the right input
  */
case class SpatialJoin(
    private val out: Pipe, 
    private val in: List[Pipe], 
    predicate: SpatialJoinPredicate,
    index: Option[(IndexMethod, List[String])],
    leftParti: Option[(PartitionMethod, List[String])],
    rightParti: Option[(PartitionMethod, List[String])]
  ) extends PigOperator(List(out), in) {

  override def lineageString: String = {
    val prefix = s"""SPATIALJOIN%${predicate.toString()}%$index%"""
    prefix + super.lineageString
  }

  /**
    * Builds the output schema from the fields of all inputs, each with the name
    * of its input pipe prepended to the field's lineage.
    *
    * Indexed inputs contribute the fields of their payload tuple; inputs
    * without a schema contribute a single unnamed bytearray field.
    */
  override def constructSchema: Option[Schema] = {
    val collected = ArrayBuffer.empty[Field]
    for (p <- inputs) {
      def tagged(f: Field): Field = Field(f.name, f.fType, p.name :: f.lineage)

      p.producer.schema match {
        case None =>
          collected += Field("", Types.ByteArrayType)
        case Some(s) if s.isIndexed =>
          // a bag of Indexes
          val indexType = s.element.valueType.asInstanceOf[IndexType]
          // An Index contains tuples with two fields: indexed column and payload;
          // the payload is again a tuple
          val payload = indexType.valueType.fields.last.fType.asInstanceOf[TupleType]
          collected ++= payload.fields.map(tagged)
        case Some(s) =>
          collected ++= s.fields.map(tagged)
      }
    }
    schema = Some(Schema(BagType(TupleType(collected.toArray))))
    schema
  }

  override def toString =
    s"""SPATIALJOIN
       |  out = $outPipeName
       |  in = ${inPipeNames.mkString(",")}
       |  inSchema = {${inputs.map(_.producer.schema).mkString(",")}}
       |  outSchema = $schema
       |  predicate = $predicate
       |  index = $index""".stripMargin

}
Example 193
Source File: Union.scala    From piglet   with Apache License 2.0 5 votes vote down vote up
package dbis.piglet.op

import dbis.piglet.schema._

import scala.collection.mutable.ArrayBuffer


/**
  * UNION operator.
  *
  * @param out output pipe
  * @param in  input pipes
  */
case class Union(private val out: Pipe, private val in: List[Pipe]) extends PigOperator(List(out), in) {

  override def lineageString: String = s"UNION%${super.lineageString}"

  /**
    * Derives the output schema from the input schemas.
    *
    * The schema is `None` when any input schema is unknown or when the inputs
    * differ in their number of fields. Otherwise each output field takes the
    * name of the first input's field and a type escalated across all inputs.
    */
  override def constructSchema: Option[Schema] = {
    val elementOf = (pipe: Pipe) => pipe.producer.schema.get.element

    // Merges two bag types field by field, escalating each pair of field types.
    val widen = (left: BagType, right: BagType) => {
      val leftFields = left.valueType.fields
      val rightFields = right.valueType.fields
      require(leftFields.length == rightFields.length)
      val merged = leftFields.zip(rightFields).map {
        case (a, b) => Field(a.name, Types.escalateTypes(a.fType, b.fType))
      }
      BagType(TupleType(merged.toArray))
    }

    val allKnown = inputs.forall(_.producer.schema.isDefined)
    schema =
      if (!allKnown) {
        // case 1: one of the input schema isn't known -> output schema = None
        None
      } else {
        val arity = inputs.head.producer.schema.get.fields.length
        val sameArity = inputs.tail.forall(_.producer.schema.get.fields.length == arity)
        if (sameArity) {
          // case 2: all input schemas have the same number of fields
          Some(Schema(inputs.map(elementOf).reduceLeft(widen)))
        } else {
          // case 3: the number of fields differ
          None
        }
      }
    schema
  }

  override def toString =
    s"""UNION
       |  out = $outPipeName
       |  in = { ${inPipeNames.mkString(",")} }
       |  inSchema = $inputSchema
       |  outSchema = $schema""".stripMargin

}
Example 194
Source File: JoinEmitter.scala    From piglet   with Apache License 2.0 5 votes vote down vote up
package dbis.piglet.codegen.flink.emitter

import dbis.piglet.codegen.{ CodeEmitter, CodeGenContext, CodeGenException }
import dbis.piglet.expr.Ref
import dbis.piglet.op.Join

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.Set
import dbis.piglet.codegen.scala_lang.ScalaEmitter
import scala.collection.mutable.ListBuffer
import dbis.piglet.codegen.flink.FlinkHelper

/** Flink code emitter for JOIN: chains `.join(...).where(...).equalTo(...)` calls over the inputs. */
class JoinEmitter extends dbis.piglet.codegen.scala_lang.JoinEmitter {
  override def template: String = """    val <out> = <rel1><rels, rel1_keys, rel2_keys:{ r,k1, k2 | .join(<r>).where(<k1>).equalTo(<k2>)}>.map{ 
                                    |      t => 
                                    |        val <pairs> = t
                                    |        <class>(<fields>)
                                    |    }""".stripMargin

  /**
    * Renders the join code for `op`.
    *
    * @throws CodeGenException when `op` has no schema
    */
  override def code(ctx: CodeGenContext, op: Join): String = {
    if (op.schema.isEmpty)
      throw CodeGenException("schema required in JOIN")

    // Key expressions of every input, rendered as tuple accessors ("_<index>").
    val keys = op.inputs.zip(op.fieldExprs).map { case (input, exprs) =>
      exprs.map(expr => s"_${FlinkHelper.getOrderIndex(input.producer.schema, expr)}")
    }

    // One (left keys, right keys) entry per pair of neighbouring inputs.
    val adjacentKeys: ListBuffer[(List[String], List[String])] =
      ListBuffer(keys.zip(keys.drop(1)): _*)

    val quotedKeys = adjacentKeys.zipWithIndex.map { case ((leftKeys, rightKeys), pos) =>
      // From the second join on, the left keys are prefixed with "_<pos>.".
      val qualifiedLeft = if (pos > 0) leftKeys.map(key => s"_$pos.$key") else leftKeys
      (FlinkHelper.printQuote(qualifiedLeft), FlinkHelper.printQuote(rightKeys))
    }
    val leftSide = quotedKeys.map(_._1)
    val rightSide = quotedKeys.map(_._2)

    val className = op.schema match {
      case None => ScalaEmitter.schemaClassName(op.outPipeName)
      case Some(s) => ScalaEmitter.schemaClassName(s.className)
    }

    // Left-nested tuple pattern: (v1,v2), ((v1,v2),v3), ...
    val pairs = (3 to op.inputs.length).foldLeft("(v1,v2)") { (nested, i) => s"($nested,v$i)" }

    // v<i>._<k> for every field k of input i; inputs without a schema contribute v<i>._0.
    val fieldList = op.inputs.zipWithIndex.flatMap { case (input, idx) =>
      val v = idx + 1
      input.producer.schema match {
        case Some(s) => s.fields.indices.map(k => s"v$v._$k")
        case None => IndexedSeq(s"v$v._0")
      }
    }

    render(
      Map("out" -> op.outPipeName,
        "rel1" -> op.inputs.head.name,
        "class" -> className,
        "rels" -> op.inputs.tail.map(_.name),
        "pairs" -> pairs,
        "rel1_keys" -> leftSide,
        "rel2_keys" -> rightSide,
        "fields" -> fieldList.mkString(", ")))
  }
}

object JoinEmitter {

  /** Shared emitter instance, created on first access. */
  lazy val instance: JoinEmitter = new JoinEmitter()
}
Example 195
Source File: FlinkHelper.scala    From piglet   with Apache License 2.0 5 votes vote down vote up
package dbis.piglet.codegen.flink

import dbis.piglet.codegen.CodeGenException
import dbis.piglet.expr.NamedField
import dbis.piglet.expr.PositionalField
import dbis.piglet.schema.Schema
import dbis.piglet.expr.Ref
import dbis.piglet.op.PigOperator
import scala.collection.mutable.ArrayBuffer

/** Helpers shared by the Flink code emitters. */
object FlinkHelper {

  /**
    * Resolves a field reference to its position in `schema`.
    *
    * Named fields are looked up in the schema, positional fields return their
    * position, and any other kind of reference resolves to 0.
    *
    * @param schema the schema to resolve against
    * @param ref    the field reference
    * @throws CodeGenException when no schema is available; the message now
    *                          names the offending reference
    */
  def getOrderIndex(schema: Option[Schema],
    ref: Ref): Int = schema match {

    case Some(s) => ref match {
      case nf @ NamedField(_, _) => s.indexOfField(nf)
      case PositionalField(pos) => pos
      case _ => 0
    }
    case None => throw new CodeGenException(s"the Flink OrderBy/Join operator needs a schema, thus, invalid field $ref")
  }

  /**
    * Builds the tuple pattern and the field access list for a join result.
    *
    * With exactly two inputs the pattern is `(v,w)` and the fields are taken
    * from `v` or `w` depending on their position in the output schema. Otherwise
    * the pattern is a left-nested tuple `(v1,v2)`, `((v1,v2),v3)`, ... and the
    * fields are `v<i>._<k>` for every field `k` of input `i` (`v<i>._0` for
    * inputs without a schema).
    *
    * @param node the join operator
    * @return (tuple pattern, comma-separated field accesses)
    */
  def emitJoinFieldList(node: PigOperator): (String, String) = {
    val rels = node.inputs
    if (rels.length == 2) {
      // Number of fields coming from the first input.
      val vsize = rels.head.inputSchema.get.fields.length
      val fields = node.schema.get.fields.zipWithIndex
        .map { case (_, i) => if (i < vsize) s"v._$i" else s"w._${i - vsize}" }
        .mkString(", ")
      ("(v,w)", fields)
    } else {
      var pairs = "(v1,v2)"
      for (i <- 3 to rels.length) {
        pairs = s"($pairs,v$i)"
      }
      val fieldList = ArrayBuffer[String]()
      for (i <- 1 to node.inputs.length) {
        node.inputs(i - 1).producer.schema match {
          case Some(s) => fieldList ++= s.fields.zipWithIndex.map { case (_, k) => s"v$i._$k" }
          case None => fieldList += s"v$i._0"
        }
      }
      (pairs, fieldList.mkString(", "))
    }
  }

  /** Wraps every value in double quotes and joins them with commas; an empty list yields `""`. */
  def printQuote(values: List[String]): String = values.mkString("\"", "\",\"", "\"")
}
Example 196
Source File: TSNEStandardExample.scala    From dl4scala   with MIT License 5 votes vote down vote up
package org.dl4scala.examples.nlp.tsne

import java.io.File

import org.datavec.api.util.ClassPathResource
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement
import org.deeplearning4j.models.word2vec.wordstore.VocabCache
import org.nd4j.linalg.api.buffer.DataBuffer
import org.nd4j.linalg.api.buffer.util.DataTypeUtil
import org.nd4j.linalg.primitives
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer


/**
 * Example program: loads word vectors from the classpath resource `words.txt`, fits a
 * Barnes-Hut t-SNE model on their weights and writes the result to a CSV file.
 */
object TSNEStandardExample {
  private val log = LoggerFactory.getLogger(TSNEStandardExample.getClass)

  /**
   * Runs the example end to end.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    import org.deeplearning4j.plot.BarnesHutTsne

    // STEP 1: initialization - fix the iteration count and select DOUBLE as the data type.
    val iterations = 100
    DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE)

    // STEP 2: read the word vectors and collect the vocabulary words.
    log.info("Load & Vectorize data....")
    val wordFile = new ClassPathResource("words.txt").getFile
    val vectors: primitives.Pair[InMemoryLookupTable[_ <: SequenceElement], VocabCache[_ <: SequenceElement]] =
      WordVectorSerializer.loadTxt(wordFile)
    val cache = vectors.getSecond
    val weights = vectors.getFirst.getSyn0 // weights of the loaded words

    // Growable list holding every word of the vocabulary, in index order.
    val cacheList = new ArrayBuffer[String]()
    for (i <- 0 until cache.numWords()) {
      cacheList += cache.wordAtIndex(i)
    }

    // STEP 3: configure the t-SNE model that is fitted below.
    log.info("Build model....")
    val tsne = new BarnesHutTsne.Builder()
      .setMaxIter(iterations)
      .theta(0.5)
      .normalize(false)
      .learningRate(500)
      .useAdaGrad(false)
      .build()

    // STEP 4: fit the model and save the coordinates, creating the output directory first.
    log.info("Store TSNE Coordinates for Plotting....")
    val outputFile = "target/archive-tmp/tsne-standard-coords.csv"
    new File(outputFile).getParentFile.mkdirs()
    tsne.fit(weights)
    tsne.saveAsFile(cacheList.asJava, outputFile)
  }
}
Example 197
Source File: MNISTVisualizer.scala    From dl4scala   with MIT License 5 votes vote down vote up
package org.dl4scala.examples.feedforward.anomalydetection

import java.awt.{GridLayout, Image}
import java.awt.image.BufferedImage
import javax.swing.{ImageIcon, JFrame, JLabel, JPanel}

import org.nd4j.linalg.api.ndarray.INDArray

import scala.collection.mutable.ArrayBuffer


/**
 * Shows a collection of 28x28 grayscale images in a Swing window.
 *
 * @param imageScale factor applied to the 28-pixel side length when displaying each image
 * @param digits     one array per image; 784 values are read from each
 * @param title      window title
 * @param gridWidth  number of columns of the grid layout
 */
class MNISTVisualizer(imageScale: Double, digits: ArrayBuffer[INDArray], title: String, gridWidth: Int) {

  /** Same as the primary constructor, with a grid that is 5 columns wide. */
  def this(imageScale: Double, digits: ArrayBuffer[INDArray], title: String) = {
    this(imageScale, digits, title, 5)
  }

  /** Builds the window, fills it with one label per image and displays it. */
  def visualize(): Unit = {
    val frame = new JFrame
    frame.setTitle(title)
    frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE)

    val panel = new JPanel
    panel.setLayout(new GridLayout(0, gridWidth))
    getComponents.foreach(label => panel.add(label))

    frame.add(panel)
    frame.setVisible(true)
    frame.pack()
  }

  /**
   * Renders every array in `digits` into a scaled image wrapped in a `JLabel`.
   *
   * @return the labels, in the same order as `digits`
   */
  def getComponents: ArrayBuffer[JLabel] = {
    val scaledSide = (imageScale * 28).toInt
    digits.map { arr =>
      val canvas = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY)
      val raster = canvas.getRaster
      // Value i of the array goes to pixel (column i % 28, row i / 28), multiplied by 255.
      for (row <- 0 until 28; col <- 0 until 28) {
        raster.setSample(col, row, 0, (255 * arr.getDouble(row * 28 + col)).toInt)
      }
      val resized = new ImageIcon(canvas).getImage
        .getScaledInstance(scaledSide, scaledSide, Image.SCALE_REPLICATE)
      new JLabel(new ImageIcon(resized))
    }
  }
}
Example 198
Source File: GeneralNetwork.scala    From deepspark   with GNU General Public License v2.0 5 votes vote down vote up
package com.github.nearbydelta.deepspark.network

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import com.github.nearbydelta.deepspark.data._
import com.github.nearbydelta.deepspark.layer.InputLayer
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.ParSeq


/**
 * A network that passes its input through an `InputLayer` before the remaining layers.
 *
 * Each override below handles the input layer and then defers to the inherited
 * implementation (or, for the forward/backward passes, to the inherited helper methods).
 *
 * @param inputLayer the layer applied to raw input; `null` only when created through the
 *                   no-argument constructor
 */
class GeneralNetwork[In, Out](var inputLayer: InputLayer[In, _]) extends Network[In, Out] {
  /** No-argument constructor; leaves `inputLayer` as `null` until `read` assigns it. */
  @deprecated(message = "This is for kryo deserialization. Please use this(inputlayer)")
  def this() = this(null)

  /**
   * Output size: that of the last layer in `layerSeq` if there is one, otherwise that of the
   * input layer, or 0 when the input layer is `null` as well.
   */
  override def NOut: Int =
    layerSeq.lastOption
      .map(_.NOut)
      .getOrElse(if (inputLayer != null) inputLayer.NOut else 0)

  /**
   * Runs `backwardSeq` on the error, hands its first result to the input layer's `backward`,
   * and returns the buffer from `backwardSeq` with the input layer's functions appended.
   */
  override def backward(error: ParSeq[DataVec]): ArrayBuffer[() => Unit] = {
    val (upperError, updates) = backwardSeq(error)
    val (_, inputUpdates) = inputLayer.backward(upperError)
    updates ++= inputUpdates.seq
  }

  /** Broadcasts the input layer, then delegates to the inherited `broadcast`. */
  override def broadcast(sc: SparkContext): Unit = {
    inputLayer.broadcast(sc)
    super.broadcast(sc)
  }

  /** Forwards a single input through the input layer and then `forwardSingle`. */
  override def forward(in: In) = {
    val fromInput = inputLayer.forward(in)
    forwardSingle(fromInput)
  }

  /** Forwards a parallel sequence through the input layer and then `forwardSeq`. */
  override def forward(in: ParSeq[In]): ParSeq[DataVec] = {
    val fromInput = inputLayer.forward(in)
    forwardSeq(fromInput)
  }

  /**
   * Forwards an RDD through the input layer, broadcasts this network on the RDD's context,
   * and then applies `forwardRDD`.
   */
  override def forward(in: RDD[(Long, In)]): RDD[(Long, DataVec)] = {
    val fromInput = inputLayer.forward(in)
    broadcast(in.context)
    forwardRDD(fromInput)
  }

  /** Initiates the input layer with `builder`, then the rest of the network; returns `this`. */
  override def initiateBy(builder: WeightBuilder): this.type = {
    inputLayer.initiateBy(builder)
    super.initiateBy(builder)
    this
  }

  /** Sum of the inherited loss and the input layer's loss. */
  override def loss: Double = super.loss + inputLayer.loss

  /** Reads the input layer from `input` first, then the inherited state. */
  override def read(kryo: Kryo, input: Input): Unit = {
    inputLayer = kryo.readClassAndObject(input).asInstanceOf[InputLayer[In, _]]
    super.read(kryo, input)
  }

  /** Applies the flag to the input layer and returns the result of the inherited call. */
  override def setUpdatable(bool: Boolean): Network[In, Out] = {
    inputLayer.setUpdatable(bool)
    super.setUpdatable(bool)
  }

  /** Unbroadcasts the input layer, then delegates to the inherited `unbroadcast`. */
  override def unbroadcast(): Unit = {
    inputLayer.unbroadcast()
    super.unbroadcast()
  }

  /** Writes the input layer to `output` first, then the inherited state. */
  override def write(kryo: Kryo, output: Output): Unit = {
    kryo.writeClassAndObject(output, inputLayer)
    super.write(kryo, output)
  }
}
Example 199
Source File: ExtractStageHelpers.scala    From akka-xml-parser   with Apache License 2.0 5 votes vote down vote up
package uk.gov.hmrc.akka.xml

import com.fasterxml.aalto.{AsyncByteArrayFeeder, AsyncXMLStreamReader}

import scala.collection.mutable.ArrayBuffer

/** Helpers that maintain a mutable set of `XMLGroupElement`s in place. */
trait ExtractStageHelpers {

  /**
   * Gives `newValue` to every element at `path` that has no value yet.
   *
   * Each matching element is removed from `xmlElementsLst` and a copy carrying the new value
   * is added in its place.
   *
   * @param xmlElementsLst the set to modify in place
   * @param path           the path an element must have to be updated
   * @param newValue       the value stored in the replacement elements
   */
  def update(xmlElementsLst: scala.collection.mutable.Set[XMLGroupElement],
             path: ArrayBuffer[String], newValue: Some[String]): Unit = {
    val wantedPath = path.toList
    // Collected into a separate set first, so the loop below never mutates what it iterates.
    val pending = xmlElementsLst.collect {
      case candidate: XMLGroupElement if candidate.value.isEmpty && candidate.xPath == wantedPath =>
        candidate
    }

    pending.foreach { stale =>
      xmlElementsLst -= stale
      xmlElementsLst += stale.copy(value = newValue)
    }
  }

  /**
   * Removes and returns the completed elements.
   *
   * An element counts as completed unless it has a non-empty path and an empty value.
   *
   * @param xmlElementsLst the set to modify in place
   * @return a new set containing the elements that were removed
   */
  def getCompletedXMLElements(xmlElementsLst: scala.collection.mutable.Set[XMLGroupElement]):
  scala.collection.mutable.Set[XMLGroupElement] = {
    val finished = xmlElementsLst.filter(e => e.xPath.isEmpty || e.value.nonEmpty)
    xmlElementsLst --= finished
    finished
  }

}
Example 200
Source File: StreamHelper.scala    From akka-xml-parser   with Apache License 2.0 5 votes vote down vote up
package uk.gov.hmrc.akka.xml

import com.fasterxml.aalto.{AsyncByteArrayFeeder, AsyncXMLStreamReader}

import scala.collection.mutable.ArrayBuffer


/** Helpers that maintain a mutable set of `XMLElement`s and render elements as XML text. */
trait StreamHelper {

  /**
   * Gives `newValue` to every element at `path` that has no value yet.
   *
   * Each matching element is removed from `xmlElementsLst` and a copy carrying the new value
   * is added in its place.
   *
   * @param xmlElementsLst the set to modify in place
   * @param path           the path an element must have to be updated
   * @param newValue       the value stored in the replacement elements
   */
  def update(xmlElementsLst: scala.collection.mutable.Set[XMLElement],
             path: ArrayBuffer[String], newValue: Some[String]): Unit = {
    val wantedPath = path.toList
    // Collected into a separate set first, so the loop below never mutates what it iterates.
    val pending = xmlElementsLst.collect {
      case candidate: XMLElement if candidate.value.isEmpty && candidate.xPath == wantedPath =>
        candidate
    }

    pending.foreach { stale =>
      xmlElementsLst -= stale
      xmlElementsLst += stale.copy(value = newValue)
    }
  }

  /**
   * Removes and returns the completed elements.
   *
   * An element counts as completed unless it has a non-empty path while both its value and
   * its attributes are empty.
   *
   * @param xmlElementsLst the set to modify in place
   * @return a new set containing the elements that were removed
   */
  def getCompletedXMLElements(xmlElementsLst: scala.collection.mutable.Set[XMLElement]):
  scala.collection.mutable.Set[XMLElement] = {
    val finished = xmlElementsLst.filter { e =>
      e.xPath.isEmpty || e.value.nonEmpty || e.attributes.nonEmpty
    }
    xmlElementsLst --= finished
    finished
  }


  /**
   * Renders an element as `<prefix:name k="v" ...>text</prefix:name>`.
   *
   * The element name is the last entry of `xPath`; the prefix comes from the implicit reader
   * and is left out when the reader reports none. Attribute values and text are inserted
   * verbatim.
   *
   * @param xPath      path of the element; its last entry is used as the tag name
   * @param attributes attributes to render, in the map's iteration order
   * @param elemText   text placed between the start and end tags
   * @param reader     reader supplying the current namespace prefix
   * @return the rendered element
   */
  def getUpdatedElement(xPath: Seq[String], attributes: Map[String, String], elemText: String)
                       (implicit reader: AsyncXMLStreamReader[AsyncByteArrayFeeder]): String = {
    val prefix = getPrefix
    val openingName = s"<$prefix${xPath.last}"
    val renderedAttributes = attributes.iterator.map { case (k, v) => s""" $k="$v"""" }.mkString
    val closingTag = getEndElement(xPath, prefix)
    s"$openingName$renderedAttributes>$elemText$closingTag"
  }

  /** The reader's prefix followed by a colon, or an empty string when it is null or empty. */
  private def getPrefix(implicit reader: AsyncXMLStreamReader[AsyncByteArrayFeeder]): String =
    Option(reader.getPrefix).filter(_.nonEmpty).fold("")(pre => s"$pre:")

  /** The closing tag for the last entry of `xPath`, using the given prefix. */
  private def getEndElement(xPath: Seq[String], prefix: String) = {
    val name = xPath.last
    s"</$prefix$name>"
  }

}