com.typesafe.scalalogging.StrictLogging Scala Examples

The following examples show how to use com.typesafe.scalalogging.StrictLogging. Each example is an excerpt from an open-source project; the source file name, the project it comes from, and its license are listed above the code.
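All of the snippets below share the same basic pattern, so a minimal, self-contained sketch of it is shown first. The package, class, and method names in this sketch are illustrative only and do not come from any of the projects listed below; the logger calls (info, debug, and the message-plus-Throwable error overload) are the same ones used throughout the examples.

package com.example.payments // hypothetical package, for illustration only

import com.typesafe.scalalogging.StrictLogging

// Mixing in StrictLogging gives the class an eagerly initialised, SLF4J-backed
// `logger` field named after the class; no further setup is needed.
class PaymentService extends StrictLogging {

  def charge(customer: String, amount: BigDecimal): Unit = {
    logger.info(s"Charging $customer: $amount")
    if (amount <= 0) {
      val error = new IllegalArgumentException(s"Invalid amount: $amount")
      // Message-plus-Throwable overload, as used in several examples below.
      logger.error(s"Charge for $customer rejected", error)
      throw error
    }
    logger.debug(s"Charge for $customer accepted")
  }
}

StrictLogging creates the logger when the instance is constructed; the related LazyLogging trait defers creation until the logger is first used, but the call sites look the same.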
Example 1
Source File: PulsarSinkTask.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.pulsar.sink

import java.util
import java.util.UUID

import com.datamountaineer.streamreactor.connect.errors.ErrorPolicyEnum
import com.datamountaineer.streamreactor.connect.pulsar.config.{PulsarConfigConstants, PulsarSinkConfig, PulsarSinkSettings}
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.connect.sink.{SinkRecord, SinkTask}

import scala.collection.JavaConverters._



  override def stop(): Unit = {
    logger.info("Stopping Pulsar sink.")
    writer.foreach(w => w.close())
    progressCounter.empty
  }

  override def flush(map: util.Map[TopicPartition, OffsetAndMetadata]): Unit = {
    require(writer.nonEmpty, "Writer is not set!")
    writer.foreach(w => w.flush())
  }

  override def version: String = manifest.version()
} 
Example 2
Source File: ElasticSinkTask.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.elastic6

import java.util

import com.datamountaineer.streamreactor.connect.elastic6.config.{ElasticConfig, ElasticConfigConstants, ElasticSettings}
import com.datamountaineer.streamreactor.connect.errors.ErrorPolicyEnum
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.connect.sink.{SinkRecord, SinkTask}

import scala.collection.JavaConverters._

class ElasticSinkTask extends SinkTask with StrictLogging {
  private var writer: Option[ElasticJsonWriter] = None
  private val progressCounter = new ProgressCounter
  private var enableProgress: Boolean = false
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  
  override def stop(): Unit = {
    logger.info("Stopping Elastic sink.")
    writer.foreach(w => w.close())
    progressCounter.empty
  }

  override def flush(map: util.Map[TopicPartition, OffsetAndMetadata]): Unit = {
    logger.info("Flushing Elastic Sink")
  }

  override def version: String = manifest.version()
} 
Example 3
Source File: KuduSinkTask.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.kudu.sink

import java.util

import com.datamountaineer.streamreactor.connect.errors.ErrorPolicyEnum
import com.datamountaineer.streamreactor.connect.kudu.config.{KuduConfig, KuduConfigConstants, KuduSettings}
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.connect.sink.{SinkRecord, SinkTask}

import scala.collection.JavaConverters._


  override def stop(): Unit = {
    logger.info("Stopping Kudu sink.")
    writer.foreach(w => w.close())
    progressCounter.empty
  }

  override def flush(map: util.Map[TopicPartition, OffsetAndMetadata]): Unit = {
    require(writer.nonEmpty, "Writer is not set!")
    writer.foreach(w => w.flush())
  }

  override def version: String = manifest.version()
} 
Example 4
Source File: JMSSinkTask.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.jms.sink

import java.util

import com.datamountaineer.streamreactor.connect.errors.ErrorPolicyEnum
import com.datamountaineer.streamreactor.connect.jms.config.{JMSConfig, JMSConfigConstants, JMSSettings}
import com.datamountaineer.streamreactor.connect.jms.sink.writer.JMSWriter
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.connect.sink.{SinkRecord, SinkTask}

import scala.collection.JavaConverters._


  override def stop(): Unit = {
    logger.info("Stopping JMS sink.")
    writer.foreach(w => w.close())
  }

  override def flush(map: util.Map[TopicPartition, OffsetAndMetadata]): Unit = {
    //TODO
    //have the writer expose a is busy; can expose an await using a countdownlatch internally
  }

  override def version: String = manifest.version()
} 
Example 5
Source File: CoapSourceTask.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.coap.source

import java.util
import java.util.concurrent.LinkedBlockingQueue

import com.datamountaineer.streamreactor.connect.coap.configs.{CoapConstants, CoapSettings, CoapSourceConfig}
import com.datamountaineer.streamreactor.connect.queues.QueueHelpers
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.source.{SourceRecord, SourceTask}

import scala.collection.JavaConverters._


class CoapSourceTask extends SourceTask with StrictLogging {
  private var readers: Set[CoapReader] = _
  private val progressCounter = new ProgressCounter
  private var enableProgress: Boolean = false
  private val queue = new LinkedBlockingQueue[SourceRecord]()
  private var batchSize: Int = CoapConstants.BATCH_SIZE_DEFAULT
  private var lingerTimeout = CoapConstants.SOURCE_LINGER_MS_DEFAULT
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  override def start(props: util.Map[String, String]): Unit = {
    logger.info(scala.io.Source.fromInputStream(getClass.getResourceAsStream("/coap-source-ascii.txt")).mkString + s" $version")
    logger.info(manifest.printManifest())

    val conf = if (context.configs().isEmpty) props else context.configs()

    val config = CoapSourceConfig(conf)
    enableProgress = config.getBoolean(CoapConstants.PROGRESS_COUNTER_ENABLED)
    val settings = CoapSettings(config)
    batchSize = config.getInt(CoapConstants.BATCH_SIZE)
    lingerTimeout = config.getInt(CoapConstants.SOURCE_LINGER_MS)
    enableProgress = config.getBoolean(CoapConstants.PROGRESS_COUNTER_ENABLED)
    readers = CoapReaderFactory(settings, queue)
  }

  override def poll(): util.List[SourceRecord] = {
    val records = new util.ArrayList[SourceRecord]()

    QueueHelpers.drainWithTimeoutNoGauva(records, batchSize, lingerTimeout * 1000000 , queue)

    if (enableProgress) {
      progressCounter.update(records.asScala.toVector)
    }
    records
  }

  override def stop(): Unit = {
    logger.info("Stopping Coap source and closing connections.")
    readers.foreach(_.stop())
    progressCounter.empty
  }

  override def version: String = manifest.version()
} 
Example 6
Source File: Main.scala    From iep-apps   with Apache License 2.0
package com.netflix.atlas.slotting

import com.amazonaws.services.autoscaling.AmazonAutoScaling
import com.amazonaws.services.dynamodbv2.AmazonDynamoDB
import com.amazonaws.services.ec2.AmazonEC2
import com.google.inject.AbstractModule
import com.google.inject.Module
import com.google.inject.Provides
import com.google.inject.multibindings.Multibinder
import com.netflix.iep.aws.AwsClientFactory
import com.netflix.iep.guice.BaseModule
import com.netflix.iep.guice.GuiceHelper
import com.netflix.iep.service.Service
import com.netflix.iep.service.ServiceManager
import com.netflix.spectator.api.NoopRegistry
import com.netflix.spectator.api.Registry
import com.typesafe.config.Config
import com.typesafe.config.ConfigFactory
import com.typesafe.scalalogging.StrictLogging
import javax.inject.Singleton

object Main extends StrictLogging {

  private def isLocalEnv: Boolean = !sys.env.contains("EC2_INSTANCE_ID")

  private def getBaseModules: java.util.List[Module] = {
    val modules = {
      GuiceHelper.getModulesUsingServiceLoader
    }

    if (isLocalEnv) {
      // If we are running in a local environment, provide simple versions of registry
      // and config bindings. These bindings are normally provided by the final package
      // config for the app in the production setup.
      modules.add(new AbstractModule {
        override def configure(): Unit = {
          bind(classOf[Registry]).toInstance(new NoopRegistry)
          bind(classOf[Config]).toInstance(ConfigFactory.load())
        }
      })
    }

    modules
  }

  def main(args: Array[String]): Unit = {
    try {
      val modules = getBaseModules
      modules.add(new ServerModule)

      val guice = new GuiceHelper
      guice.start(modules)
      guice.getInjector.getInstance(classOf[ServiceManager])
      guice.addShutdownHook()
    } catch {
      // Send exceptions to main log file instead of wherever STDERR is sent for the process
      case t: Throwable => logger.error("fatal error on startup", t)
    }
  }

  class ServerModule extends BaseModule {
    override def configure(): Unit = {
      val serviceBinder = Multibinder.newSetBinder(binder(), classOf[Service])
      serviceBinder.addBinding().to(classOf[SlottingService])
    }

    @Provides
    @Singleton
    protected def providesAmazonDynamoDB(factory: AwsClientFactory): AmazonDynamoDB = {
      factory.getInstance(classOf[AmazonDynamoDB])
    }

    @Provides
    @Singleton
    protected def providesAmazonEC2(factory: AwsClientFactory): AmazonEC2 = {
      factory.getInstance(classOf[AmazonEC2])
    }

    @Provides
    @Singleton
    protected def providesAmazonAutoScaling(factory: AwsClientFactory): AmazonAutoScaling = {
      factory.getInstance(classOf[AmazonAutoScaling])
    }
  }
} 
Example 7
Source File: ConfigManager.scala    From iep-apps   with Apache License 2.0
package com.netflix.iep.lwc

import akka.stream.Attributes
import akka.stream.FlowShape
import akka.stream.Inlet
import akka.stream.Outlet
import akka.stream.stage.GraphStage
import akka.stream.stage.GraphStageLogic
import akka.stream.stage.InHandler
import akka.stream.stage.OutHandler
import com.netflix.iep.lwc.ForwardingService.Message
import com.netflix.iep.lwc.fwd.cw.ClusterConfig
import com.typesafe.scalalogging.StrictLogging

class ConfigManager
    extends GraphStage[FlowShape[Message, Map[String, ClusterConfig]]]
    with StrictLogging {

  private val in = Inlet[Message]("ConfigManager.in")
  private val out = Outlet[Map[String, ClusterConfig]]("ConfigManager.out")

  override val shape: FlowShape[Message, Map[String, ClusterConfig]] = FlowShape(in, out)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = {
    new GraphStageLogic(shape) with InHandler with OutHandler {

      private val configs = scala.collection.mutable.AnyRefMap.empty[String, ClusterConfig]

      override def onPush(): Unit = {
        val msg = grab(in)
        if (msg.response.isUpdate) {
          val cluster = msg.cluster
          try {
            configs += cluster -> msg.response.clusterConfig
            logger.info(s"updated configuration for cluster $cluster")
          } catch {
            case e: Exception =>
              logger.warn(s"invalid config for cluster $cluster", e)
          }
        } else {
          configs -= msg.cluster
          logger.info(s"deleted configuration for cluster ${msg.cluster}")
        }
        push(out, configs.toMap)
      }

      override def onPull(): Unit = {
        pull(in)
      }

      override def onUpstreamFinish(): Unit = {
        completeStage()
      }

      setHandlers(in, out, this)
    }
  }
} 
Example 8
Source File: UpdateApi.scala    From iep-apps   with Apache License 2.0
package com.netflix.atlas.aggregator

import javax.inject.Inject
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.model.HttpEntity
import akka.http.scaladsl.model.HttpResponse
import akka.http.scaladsl.model.MediaTypes
import akka.http.scaladsl.model.StatusCode
import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.server.Route
import com.fasterxml.jackson.core.JsonParser
import com.netflix.atlas.akka.CustomDirectives._
import com.netflix.atlas.akka.WebApi
import com.netflix.atlas.core.validation.ValidationResult
import com.netflix.atlas.eval.stream.Evaluator
import com.typesafe.scalalogging.StrictLogging

class UpdateApi @Inject()(
  evaluator: Evaluator,
  aggrService: AtlasAggregatorService
) extends WebApi
    with StrictLogging {

  import UpdateApi._

  require(aggrService != null, "no binding for aggregate registry")

  def routes: Route = {
    endpointPath("api" / "v4" / "update") {
      post {
        parseEntity(customJson(p => processPayload(p, aggrService))) { response =>
          complete(response)
        }
      }
    }
  }
}

object UpdateApi {
  private val decoder = PayloadDecoder.default

  private[aggregator] def processPayload(
    parser: JsonParser,
    service: AtlasAggregatorService
  ): HttpResponse = {
    val result = decoder.decode(parser, service)
    createResponse(result.numDatapoints, result.failures)
  }

  private val okResponse = {
    val entity = HttpEntity(MediaTypes.`application/json`, "{}")
    HttpResponse(StatusCodes.OK, entity = entity)
  }

  private def createErrorResponse(status: StatusCode, msg: FailureMessage): HttpResponse = {
    val entity = HttpEntity(MediaTypes.`application/json`, msg.toJson)
    HttpResponse(status, entity = entity)
  }

  private def createResponse(numDatapoints: Int, failures: List[ValidationResult]): HttpResponse = {
    if (failures.isEmpty) {
      okResponse
    } else {
      val numFailures = failures.size
      if (numDatapoints > numFailures) {
        // Partial failure
        val msg = FailureMessage.partial(failures, numFailures)
        createErrorResponse(StatusCodes.Accepted, msg)
      } else {
        // All datapoints dropped
        val msg = FailureMessage.error(failures, numFailures)
        createErrorResponse(StatusCodes.BadRequest, msg)
      }
    }
  }
} 
Example 9
Source File: SchemaRegistryService.scala    From kafka-testing   with Apache License 2.0
package com.landoop.kafka.testing

import java.net.{Socket, SocketException}
import java.util.Properties

import com.typesafe.scalalogging.StrictLogging
import io.confluent.kafka.schemaregistry.avro.AvroCompatibilityLevel
import io.confluent.kafka.schemaregistry.client.rest.RestService
import io.confluent.kafka.schemaregistry.rest.{SchemaRegistryConfig, SchemaRegistryRestApplication}
import io.confluent.kafka.schemaregistry.storage.{SchemaRegistry, SchemaRegistryIdentity}
import org.eclipse.jetty.server.Server

class SchemaRegistryService(val port: Int,
                            val zookeeperConnection: String,
                            val kafkaTopic: String,
                            val avroCompatibilityLevel: AvroCompatibilityLevel,
                            val masterEligibility: Boolean) extends StrictLogging {

  private val app = new SchemaRegistryRestApplication({
    val prop = new Properties
    prop.setProperty("port", port.asInstanceOf[Integer].toString)
    prop.setProperty(SchemaRegistryConfig.KAFKASTORE_CONNECTION_URL_CONFIG, zookeeperConnection)
    prop.put(SchemaRegistryConfig.KAFKASTORE_TOPIC_CONFIG, kafkaTopic)
    prop.put(SchemaRegistryConfig.COMPATIBILITY_CONFIG, avroCompatibilityLevel.toString)
    prop.put(SchemaRegistryConfig.MASTER_ELIGIBILITY, masterEligibility.asInstanceOf[AnyRef])
    prop
  })

  val restServer = startServer(port)

  var Endpoint: String = getEndpoint(restServer)

  val restClient = new RestService(Endpoint)

  def startServer(port: Int, retries: Int = 5): Option[Server] = {
    var retry = retries > 0
    var restServer: Option[Server] = None
    if (retry) {
      if (isPortInUse(port)) {
        logger.info(s"Schema Registry Port $port is already in use")
        Thread.sleep(2000)
        restServer = startServer(port, retries - 1) // keep the result of the retry so it is actually returned
      } else {
        restServer = Some(app.createServer)
        restServer.get.start()
      }
    }
    restServer
  }

  def getEndpoint(restServer: Option[Server]): String = {
    if (restServer.isDefined) {
      val uri = restServer.get.getURI.toString
      if (uri.endsWith("/")) {
        uri.substring(0, uri.length - 1)
      } else {
        uri
      }
    } else ""
  }

  private def isPortInUse(port: Integer): Boolean = try {
    new Socket("127.0.0.1", port).close()
    true
  }
  catch {
    case e: SocketException => false
  }

  def close(): Unit = {
    if (restServer.isDefined) {
      restServer.get.stop()
      restServer.get.join()
    }
  }

  def isMaster: Boolean = app.schemaRegistry.isMaster

  def setMaster(schemaRegistryIdentity: SchemaRegistryIdentity): Unit =
    app.schemaRegistry.setMaster(schemaRegistryIdentity)

  def myIdentity: SchemaRegistryIdentity = app.schemaRegistry.myIdentity

  def masterIdentity: SchemaRegistryIdentity = app.schemaRegistry.masterIdentity

  def schemaRegistry: SchemaRegistry = app.schemaRegistry
} 
Example 10
Source File: RegExTest.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.ftp.source

import com.datamountaineer.streamreactor.connect.ftp.source
import com.typesafe.scalalogging.StrictLogging
import org.apache.commons.net.ftp.FTPFile
import org.mockito.MockitoSugar
import org.scalatest.BeforeAndAfter
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers


class RegExTest extends AnyFunSuite with Matchers with BeforeAndAfter with StrictLogging with MockitoSugar  {
  def mockFile(name: String) = {
    val f = mock[FTPFile]
    when(f.isFile).thenReturn(true)
    when(f.isDirectory).thenReturn(false)
    when(f.getName()).thenReturn(name)
    f
  }
  test("Matches RegEx"){
    FtpSourceConfig.fileFilter -> ".*"

    val f: source.AbsoluteFtpFile = new AbsoluteFtpFile(mockFile("file.txt"), "\\")
    f.name.matches(".*") shouldBe true
    f.name.matches("a") shouldBe false
  }
} 
Example 11
Source File: ManyFilesTest.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.ftp.source


import com.typesafe.scalalogging.StrictLogging
import org.scalatest.BeforeAndAfter
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

import scala.collection.JavaConverters._


class ManyFilesTest extends AnyFunSuite with Matchers with BeforeAndAfter with StrictLogging {
  val ftpServer = new EmbeddedFtpServer(3333)

  val fileCount = 132
  val sliceSize = 1024
  val maxPollRecords = 74

  val lineSep = System.getProperty("line.separator")
  val fileContent = (1 to 12000).map(index => s"line_${"%010d".format(index)}${lineSep}").mkString.getBytes

  val fileName = "the_file_name"
  val filePath = s"/folder/${fileName}"

  val sourceConfig = Map(
    FtpSourceConfig.Address -> s"${ftpServer.host}:${ftpServer.port}",
    FtpSourceConfig.User -> ftpServer.username,
    FtpSourceConfig.Password -> ftpServer.password,
    FtpSourceConfig.RefreshRate -> "PT0S",
    FtpSourceConfig.MonitorTail -> "/folder/:output_topic",
    FtpSourceConfig.MonitorSliceSize -> sliceSize.toString,
    FtpSourceConfig.FileMaxAge -> "P7D",
    FtpSourceConfig.KeyStyle -> "string",
    FtpSourceConfig.fileFilter -> ".*",
    FtpSourceConfig.FtpMaxPollRecords -> s"${maxPollRecords}",
    FtpSourceConfig.KeyStyle -> "struct"
  )


  test("Read only FtpMaxPollRecords even if using MonitorSliceSize") {
    val fs = new FileSystem(ftpServer.rootDir).clear()
    val cfg = new FtpSourceConfig(sourceConfig.asJava)
    val offsets = new DummyOffsetStorage
    (0 to fileCount).map(index => fs.applyChanges(Seq(s"${filePath}_${index}" -> Append(fileContent))))

    val poller = new FtpSourcePoller(cfg, offsets)
    ftpServer.start()
    val slices = poller.poll()
    (slices.size) shouldBe (maxPollRecords)
    ftpServer.stop()
  }
} 
Example 12
Source File: FtpFileListerTest.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.ftp.source

import com.typesafe.scalalogging.StrictLogging
import org.apache.commons.net.ftp.{FTPClient, FTPFile}
import org.mockito.MockitoSugar
import org.scalatest.BeforeAndAfter
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers



    val dira = mockDir("dira")
    val dirb = mockDir("dirb")
    val thisDir = mockDir(".")
    val parentDir = mockDir("..")
    when(ftp.listFiles("/a/")).thenReturn(Array[FTPFile](dira, dirb, thisDir, parentDir))

    val path = mockDir("path")
    when(ftp.listFiles("/a/dira/")).thenReturn(Array[FTPFile](path, thisDir, parentDir))

    val file1 = mockFile("file1.txt")
    when(ftp.listFiles("/a/dira/path/")).thenReturn(Array[FTPFile](file1, thisDir, parentDir))

    val nopath = mockDir("nopath")
    when(ftp.listFiles("/a/dirb/")).thenReturn(Array[FTPFile](nopath, path, thisDir, parentDir))
    when(ftp.listFiles("/a/dirb/nopath/")).thenThrow(new RuntimeException("Should not list this directory"))

    val file3 = mockFile("file3.txt")
    val file4 = mockFile("file4.csv")
    when(ftp.listFiles("/a/dirb/path/")).thenReturn(Array[FTPFile](file3, file4, thisDir, parentDir))

    FtpFileLister(ftp).listFiles("/a/dir?/path/*.txt").toList should contain theSameElementsAs Seq(
      AbsoluteFtpFile(file1, "/a/dira/path/"),
      AbsoluteFtpFile(file3, "/a/dirb/path/")
    )
  }
} 
Example 13
Source File: FtpFileLister.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.ftp.source

import java.nio.file.{FileSystems, Paths}
import java.time.{Duration, Instant}

import com.typesafe.scalalogging.StrictLogging
import org.apache.commons.net.ftp.{FTPClient, FTPFile}

// org.apache.commons.net.ftp.FTPFile only contains the relative path
case class AbsoluteFtpFile(ftpFile:FTPFile, parentDir:String) {
  def name() = ftpFile.getName
  def size() = ftpFile.getSize
  def timestamp() = ftpFile.getTimestamp.toInstant
  def path() = Paths.get(parentDir, name).toString
  def age(): Duration = Duration.between(timestamp, Instant.now)
}

case class FtpFileLister(ftp: FTPClient) extends StrictLogging {

  def pathMatch(pattern: String, path: String):Boolean = {
    val g = s"glob:$pattern"
    FileSystems.getDefault.getPathMatcher(g).matches(Paths.get(path))
  }

  def isGlobPattern(pattern: String): Boolean = List("*", "?", "[", "{").exists(pattern.contains(_))

  def listFiles(path: String) : Seq[AbsoluteFtpFile] = {
    val pathParts : Seq[String] = path.split("/")

    val (basePath, patterns) = pathParts.zipWithIndex.view.find{case (part, _) => isGlobPattern(part)} match {
      case Some((_, index)) => pathParts.splitAt(index)
      case _ => (pathParts.init, Seq[String](pathParts.last))
    }

    def iter(basePath: String, patterns: List[String]) : Seq[AbsoluteFtpFile] = {
      Option(ftp.listFiles(basePath + "/")) match {
        case Some(files) => patterns match {
          case pattern :: Nil => {
            files.filter(f => f.isFile && pathMatch(pattern, f.getName))
              .map(AbsoluteFtpFile(_, basePath + "/"))
          }
          case pattern :: rest => {
            files.filter(f => f.getName() != "." && f.getName() != ".." && pathMatch(pattern, f.getName))
              .flatMap(f => iter(Paths.get(basePath, f.getName).toString, rest))
          }
          case _ => Seq()
        }
        case _ => Seq()
      }
    }

    iter(Paths.get("/", basePath:_*).toString, patterns.toList)
  }
} 
Example 14
Source File: ConnectFileMetaDataStore.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.ftp.source

import java.time.Instant
import java.util

import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.storage.OffsetStorageReader

import scala.collection.JavaConverters._
import scala.collection.mutable

// allows storage and retrieval of file metadata through the Connect framework
class ConnectFileMetaDataStore(offsetStorage: OffsetStorageReader) extends FileMetaDataStore with StrictLogging {
  // connect offsets aren't directly committed, hence we'll cache them
  private val cache = mutable.Map[String, FileMetaData]()

  override def get(path: String): Option[FileMetaData] =
    cache.get(path).orElse({
      val stored = getFromStorage(path)
      stored.foreach(set(path,_))
      stored
    })

  override def set(path: String, fileMetaData: FileMetaData): Unit = {
    logger.debug(s"ConnectFileMetaDataStore path = ${path}, fileMetaData.offset = ${fileMetaData.offset}, fileMetaData.attribs.size = ${fileMetaData.attribs.size}")
    cache.put(path, fileMetaData)
  }

  // cache couldn't provide us the info. this is a rather expensive operation (?)
  def getFromStorage(path: String): Option[FileMetaData] =
    offsetStorage.offset(Map("path" -> path).asJava) match {
      case null =>
        logger.info(s"meta store storage HASN'T ${path}")
        None
      case o =>
        logger.info(s"meta store storage has ${path}")
        Some(connectOffsetToFileMetas(path, o))
    }

  def fileMetasToConnectPartition(meta:FileMetaData): util.Map[String, String] = {
    Map("path" -> meta.attribs.path).asJava
  }

  def connectOffsetToFileMetas(path:String, o:AnyRef): FileMetaData = {
    val jm = o.asInstanceOf[java.util.Map[String, AnyRef]]
    FileMetaData(
      FileAttributes(
        path,
        jm.get("size").asInstanceOf[Long],
        Instant.ofEpochMilli(jm.get("timestamp").asInstanceOf[Long])
      ),
      jm.get("hash").asInstanceOf[String],
      Instant.ofEpochMilli(jm.get("firstfetched").asInstanceOf[Long]),
      Instant.ofEpochMilli(jm.get("lastmodified").asInstanceOf[Long]),
      Instant.ofEpochMilli(jm.get("lastinspected").asInstanceOf[Long]),
      jm.asScala.getOrElse("offset", -1L).asInstanceOf[Long]
    )
  }

  def fileMetasToConnectOffset(meta: FileMetaData): util.Map[String, Any] = {
    Map("size" -> meta.attribs.size,
      "timestamp" -> meta.attribs.timestamp.toEpochMilli,
      "hash" -> meta.hash,
      "firstfetched" -> meta.firstFetched.toEpochMilli,
      "lastmodified" -> meta.lastModified.toEpochMilli,
      "lastinspected" -> meta.lastInspected.toEpochMilli,
      "offset" -> meta.offset
    ).asJava
  }
} 
Example 15
Source File: FtpSourceConnector.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.ftp.source

import java.util

import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.errors.ConnectException
import org.apache.kafka.connect.source.SourceConnector

import scala.collection.JavaConverters._
import scala.util.{Failure, Try}

class FtpSourceConnector extends SourceConnector with StrictLogging {
  private var configProps : Option[util.Map[String, String]] = None
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  override def taskClass(): Class[_ <: Task] = classOf[FtpSourceTask]

  override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] = {
    logger.info(s"Setting task configurations for $maxTasks workers.")
    configProps match {
      case Some(props) => (1 to maxTasks).map(_ => props).toList.asJava
      case None => throw new ConnectException("cannot provide taskConfigs without being initialised")
    }
  }

  override def stop(): Unit = {
    logger.info("stop")
  }

  override def start(props: util.Map[String, String]): Unit = {
    logger.info(scala.io.Source.fromInputStream(getClass.getResourceAsStream("/ftp-source-ascii.txt")).mkString + s" $version")
    logger.info(s"start FtpSourceConnector")

    configProps = Some(props)
    Try(new FtpSourceConfig(props)) match {
      case Failure(f) => throw new ConnectException("Couldn't start due to configuration error: " + f.getMessage, f)
      case _ =>
    }
  }

  override def version(): String = manifest.version()

  override def config() = FtpSourceConfig.definition
} 
Example 16
Source File: Retries.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.voltdb.writers

import com.typesafe.scalalogging.StrictLogging

import scala.util.Try

trait Retries extends StrictLogging {
  def withRetries[T](retries: Int, retryInterval: Long, errorMessage: Option[String])(thunk: => T): T = {
    try {
      thunk
    }
    catch {
      case t: Throwable =>
        errorMessage.foreach(m => logger.error(m, t))
        if (retries - 1 <= 0) throw t
        Try(Thread.sleep(retryInterval))
        withRetries(retries - 1, retryInterval, errorMessage)(thunk)
    }
  }
} 
Example 17
Source File: CoapReaderFactory.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.coap.source

import java.util
import java.util.concurrent.LinkedBlockingQueue

import com.datamountaineer.streamreactor.connect.coap.configs.CoapSetting
import com.datamountaineer.streamreactor.connect.coap.connection.CoapManager
import com.datamountaineer.streamreactor.connect.coap.domain.CoapMessageConverter
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.source.SourceRecord
import org.eclipse.californium.core.{CoapHandler, CoapObserveRelation, CoapResponse, WebLink}


class MessageHandler(resource: String, topic: String, queue: LinkedBlockingQueue[SourceRecord]) extends CoapHandler with StrictLogging {
  val converter = CoapMessageConverter()

  override def onError(): Unit = {
    logger.warn(s"Message dropped for $topic!")
  }

  override def onLoad(response: CoapResponse): Unit = {
    val records = converter.convert(resource, topic, response.advanced())
    logger.debug(s"Received ${response.advanced().toString} for $topic")
    logger.debug(s"Records in queue ${queue.size()} for $topic")
    queue.put(records)
  }
} 
Example 18
Source File: DTLSConnectionFn.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.coap.connection

import java.io.FileInputStream
import java.net.{ConnectException, InetAddress, InetSocketAddress, URI}
import java.security.cert.Certificate
import java.security.{KeyStore, PrivateKey}

import com.datamountaineer.streamreactor.connect.coap.configs.{CoapConstants, CoapSetting}
import com.typesafe.scalalogging.StrictLogging
import org.eclipse.californium.core.CoapClient
import org.eclipse.californium.core.coap.CoAP
import org.eclipse.californium.core.network.CoapEndpoint
import org.eclipse.californium.core.network.config.NetworkConfig
import org.eclipse.californium.scandium.DTLSConnector
import org.eclipse.californium.scandium.config.DtlsConnectorConfig
import org.eclipse.californium.scandium.dtls.cipher.CipherSuite
import org.eclipse.californium.scandium.dtls.pskstore.InMemoryPskStore


  def discoverServer(address: String, uri: URI): URI = {
    val client = new CoapClient(s"${uri.getScheme}://$address:${uri.getPort.toString}/.well-known/core")
    client.useNONs()
    val response = client.get()

    if (response != null) {
      logger.info(s"Discovered Server ${response.advanced().getSource.toString}.")
      new URI(uri.getScheme,
        uri.getUserInfo,
        response.advanced().getSource.getHostName,
        response.advanced().getSourcePort,
        uri.getPath,
        uri.getQuery,
        uri.getFragment)
    } else {
      logger.error(s"Unable to find any servers on local network with multicast address $address.")
      throw new ConnectException(s"Unable to find any servers on local network with multicast address $address.")
    }
  }
} 
Example 19
Source File: CoapSinkConnector.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.coap.sink

import java.util

import com.datamountaineer.streamreactor.connect.coap.configs.{CoapConstants, CoapSinkConfig}
import com.datamountaineer.streamreactor.connect.config.Helpers
import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigDef
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.sink.SinkConnector

import scala.collection.JavaConverters._


class CoapSinkConnector extends SinkConnector with StrictLogging {
  private var configProps: util.Map[String, String] = _
  private val configDef = CoapSinkConfig.config
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  override def taskClass(): Class[_ <: Task] = classOf[CoapSinkTask]
  override def start(props: util.Map[String, String]): Unit = {
    Helpers.checkInputTopics(CoapConstants.COAP_KCQL, props.asScala.toMap)
    configProps = props
  }

  override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] = {
    logger.info(s"Setting task configurations for $maxTasks workers.")
    (1 to maxTasks).map(_ => configProps).toList.asJava
  }

  override def stop(): Unit = {}
  override def config(): ConfigDef = configDef
  override def version(): String = manifest.version()
} 
Example 20
Source File: CoapSinkTask.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.coap.sink

import java.util

import com.datamountaineer.streamreactor.connect.coap.configs.{CoapConstants, CoapSettings, CoapSinkConfig}
import com.datamountaineer.streamreactor.connect.errors.ErrorPolicyEnum
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.connect.sink.{SinkRecord, SinkTask}

import scala.collection.JavaConverters._
import scala.collection.mutable


class CoapSinkTask extends SinkTask with StrictLogging {
  private val writers = mutable.Map.empty[String, CoapWriter]
  private val progressCounter = new ProgressCounter
  private var enableProgress: Boolean = false
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  override def start(props: util.Map[String, String]): Unit = {
    logger.info(scala.io.Source.fromInputStream(getClass.getResourceAsStream("/coap-sink-ascii.txt")).mkString + s" $version")
    logger.info(manifest.printManifest())

    val conf = if (context.configs().isEmpty) props else context.configs()

    val sinkConfig = CoapSinkConfig(conf)
    enableProgress = sinkConfig.getBoolean(CoapConstants.PROGRESS_COUNTER_ENABLED)
    val settings = CoapSettings(sinkConfig)

    //if error policy is retry set retry interval
    if (settings.head.errorPolicy.getOrElse(ErrorPolicyEnum.THROW).equals(ErrorPolicyEnum.RETRY)) {
      context.timeout(sinkConfig.getString(CoapConstants.ERROR_RETRY_INTERVAL).toLong)
    }
    settings.map(s => (s.kcql.getSource, CoapWriter(s))).map({ case (k, v) => writers.put(k, v) })
  }

  override def put(records: util.Collection[SinkRecord]): Unit = {
    records.asScala.map(r => writers(r.topic()).write(List(r)))
    val seq = records.asScala.toVector
    if (enableProgress) {
      progressCounter.update(seq)
    }
  }

  override def stop(): Unit = {
    writers.foreach({ case (t, w) =>
      logger.info(s"Shutting down writer for $t")
      w.stop()
    })
    progressCounter.empty
  }

  override def flush(map: util.Map[TopicPartition, OffsetAndMetadata]): Unit = {}

  override def version: String = manifest.version()

} 
Example 21
Source File: HazelCastWriter.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.hazelcast.writers

import java.util.concurrent.Executors

import com.datamountaineer.streamreactor.connect.concurrent.ExecutorExtension._
import com.datamountaineer.streamreactor.connect.concurrent.FutureAwaitWithFailFastFn
import com.datamountaineer.streamreactor.connect.errors.ErrorHandler
import com.datamountaineer.streamreactor.connect.hazelcast.config.{HazelCastSinkSettings, HazelCastStoreAsType, TargetType}
import com.datamountaineer.streamreactor.connect.schemas.ConverterUtil
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.sink.SinkRecord

import scala.concurrent.duration._
import scala.util.{Failure, Success}


  def write(records: Seq[SinkRecord]): Unit = {
    if (records.isEmpty) {
      logger.debug("No records received.")
    } else {
      logger.debug(s"Received ${records.size} records.")
      if (settings.allowParallel) parallelWrite(records) else sequentialWrite(records)
      logger.debug(s"Written ${records.size}")
    }
  }

  def sequentialWrite(records: Seq[SinkRecord]): Any = {
    try {
      records.foreach(r => insert(r))
    } catch {
      case t: Throwable =>
        logger.error(s"There was an error inserting the records ${t.getMessage}", t)
        handleTry(Failure(t))
    }
  }

  def parallelWrite(records: Seq[SinkRecord]): Any = {
    logger.warn("Running parallel writes! Order of writes not guaranteed.")
    val executor = Executors.newFixedThreadPool(settings.threadPoolSize)

    try {
      val futures = records.map { record =>
        executor.submit {
          insert(record)
          ()
        }
      }

      //when the call returns the pool is shutdown
      FutureAwaitWithFailFastFn(executor, futures, 1.hours)
      handleTry(Success(()))
      logger.debug(s"Processed ${futures.size} records.")
    }
    catch {
      case t: Throwable =>
        logger.error(s"There was an error inserting the records ${t.getMessage}", t)
        handleTry(Failure(t))
    }
  }

  def insert(record: SinkRecord): Unit = {
    val writer = writers.get(record.topic())
    writer.foreach(w => w.write(record))
  }

  def close(): Unit = {
    logger.info("Shutting down Hazelcast client.")
    writers.values.foreach(_.close)
    settings.client.shutdown()
  }

  def flush(): Unit = {}
} 
Example 22
Source File: HazelCastSinkTask.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.hazelcast.sink

import java.util

import com.datamountaineer.streamreactor.connect.errors.ErrorPolicyEnum
import com.datamountaineer.streamreactor.connect.hazelcast.config.{HazelCastSinkConfig, HazelCastSinkConfigConstants, HazelCastSinkSettings}
import com.datamountaineer.streamreactor.connect.hazelcast.writers.HazelCastWriter
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.connect.sink.{SinkRecord, SinkTask}

import scala.collection.JavaConverters._


  override def stop(): Unit = {
    logger.info("Stopping Hazelcast sink.")
    writer.foreach(w => w.close())
    progressCounter.empty
  }

  override def flush(map: util.Map[TopicPartition, OffsetAndMetadata]): Unit = {
    require(writer.nonEmpty, "Writer is not set!")
    writer.foreach(w => w.flush())
  }

  override def version: String = manifest.version()
} 
Example 23
Source File: StructFieldsExtractor.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.voltdb

import java.text.SimpleDateFormat
import java.util.TimeZone

import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.data.{Field, Struct, _}

import scala.collection.JavaConverters._

trait FieldsValuesExtractor {
  def get(struct: Struct): Map[String, Any]
}

case class StructFieldsExtractor(targetTable: String,
                                 includeAllFields: Boolean,
                                 fieldsAliasMap: Map[String, String],
                                 isUpsert: Boolean = false) extends FieldsValuesExtractor with StrictLogging {
  require(targetTable != null && targetTable.trim.length > 0)

  def get(struct: Struct): Map[String, Any] = {
    val schema = struct.schema()
    val fields: Seq[Field] = {
      if (includeAllFields) {
        schema.fields().asScala
      } else {
        val selectedFields = schema.fields().asScala.filter(f => fieldsAliasMap.contains(f.name()))
        val diffSet = fieldsAliasMap.keySet.diff(selectedFields.map(_.name()).toSet)
        if (diffSet.nonEmpty) {
          val errMsg = s"Following columns ${diffSet.mkString(",")} have not been found. Available columns:${fieldsAliasMap.keys.mkString(",")}"
          logger.error(errMsg)
          sys.error(errMsg)
        }
        selectedFields
      }
    }

    //need to select all fields including null. the stored proc needs a fixed set of params
    fields.map { field =>
      val schema = field.schema()
      val value = Option(struct.get(field))
        .map { value =>
          //handle specific schema
          schema.name() match {
            case Decimal.LOGICAL_NAME =>
              value.asInstanceOf[Any] match {
                case _:java.math.BigDecimal => value
                case arr: Array[Byte] => Decimal.toLogical(schema, arr)
                case _ => throw new IllegalArgumentException(s"${field.name()} is not handled for value:$value")
              }
            case Time.LOGICAL_NAME =>
              value.asInstanceOf[Any] match {
                case i: Int => StructFieldsExtractor.TimeFormat.format(Time.toLogical(schema, i))
                case d:java.util.Date => StructFieldsExtractor.TimeFormat.format(d)
                case _ => throw new IllegalArgumentException(s"${field.name()} is not handled for value:$value")
              }

            case Timestamp.LOGICAL_NAME =>
              value.asInstanceOf[Any] match {
                case d:java.util.Date => StructFieldsExtractor.DateFormat.format(d)
                case l: Long => StructFieldsExtractor.DateFormat.format(Timestamp.toLogical(schema, l))
                case _ => throw new IllegalArgumentException(s"${field.name()} is not handled for value:$value")
              }

            case _ => value
          }
        }.orNull

      fieldsAliasMap.getOrElse(field.name(), field.name()) -> value
    }.toMap
  }
}


object StructFieldsExtractor {
  val DateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")
  val TimeFormat: SimpleDateFormat = new SimpleDateFormat("HH:mm:ss.SSSZ")
  DateFormat.setTimeZone(TimeZone.getTimeZone("UTC"))
} 
Example 24
Source File: VoltDbMetadataReader.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.voltdb.writers

import com.typesafe.scalalogging.StrictLogging
import org.voltdb.client.Client
import org.voltdb.{VoltTable, VoltType}

object VoltDbMetadataReader extends StrictLogging {

  def getProcedureParameters(client: Client, tableName: String): List[String] = {
    val rs = getMetadata(client, "COLUMNS")
    val params = rs.flatMap { vt =>
      vt.advanceRow()
      val nbrRows = vt.getRowCount
      (0 until nbrRows).map(vt.fetchRow)
        .filter(_.getString("TABLE_NAME").toLowerCase == tableName.toLowerCase)
        .map(row => row.getString("COLUMN_NAME") -> row.get("ORDINAL_POSITION", VoltType.INTEGER).asInstanceOf[Int])
    }
      .sortBy { case (_, ordinal) => ordinal }
      .map { case (column, _) => column }
      .toList

    if (params.isEmpty) logger.error(s"Unable to find parameters for table $tableName in Voltdb")

    params
  }

  private def getMetadata(client: Client, metadata: String): Array[VoltTable] = {
    client.callProcedure("@SystemCatalog", metadata).getResults
  }


} 
Example 25
Source File: VoltDbWriter.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.voltdb.writers

import com.datamountaineer.streamreactor.connect.errors.ErrorHandler
import com.datamountaineer.streamreactor.connect.schemas.ConverterUtil
import com.datamountaineer.streamreactor.connect.sink.DbWriter
import com.datamountaineer.streamreactor.connect.voltdb.config.VoltSettings
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.data.Struct
import org.apache.kafka.connect.sink.SinkRecord
import org.apache.zookeeper.server.quorum.QuorumPeerConfig.ConfigException
import org.voltdb.client.{ClientConfig, ClientFactory}

import scala.util.Try

class VoltDbWriter(settings: VoltSettings) extends DbWriter with StrictLogging with ConverterUtil with ErrorHandler {

  //ValidateStringParameterFn(settings.servers, "settings")
  //ValidateStringParameterFn(settings.user, "settings")

  //initialize error tracker
  initialize(settings.maxRetries, settings.errorPolicy)

  private val voltConfig = new ClientConfig(settings.user, settings.password)
  private val client = ClientFactory.createClient(voltConfig)
  VoltConnectionConnectFn(client, settings)

  private val proceduresMap = settings.fieldsExtractorMap.values.map { extract =>
    val procName = s"${extract.targetTable}.${if (extract.isUpsert) "upsert" else "insert"}"
    logger.info(s"Retrieving the metadata for $procName ...")
    val fields = VoltDbMetadataReader.getProcedureParameters(client, extract.targetTable).map(_.toUpperCase)
    logger.info(s"$procName expected arguments are: ${fields.mkString(",")}")
    extract.targetTable -> ProcAndFields(procName, fields)
  }.toMap

  override def write(records: Seq[SinkRecord]): Unit = {
    if (records.isEmpty) {
      logger.debug("No records received.")
    } else {
      val t = Try(records.withFilter(_.value() != null).foreach(insert))
      t.foreach(_ => logger.info("Writing complete"))
      handleTry(t)
    }
  }

  private def insert(record: SinkRecord) = {
    require(record.value().getClass == classOf[Struct], "Only Struct payloads are handled")
    val extractor = settings.fieldsExtractorMap.getOrElse(record.topic(),
      throw new ConfigException(s"${record.topic()} is not handled by the configuration:${settings.fieldsExtractorMap.keys.mkString(",")}"))

    val fieldsAndValuesMap = extractor.get(record.value().asInstanceOf[Struct]).map { case (k, v) => (k.toUpperCase, v) }
    logger.info(fieldsAndValuesMap.mkString(","))
    val procAndFields: ProcAndFields = proceduresMap(extractor.targetTable)
    //get the list of arguments to pass to the table insert/upsert procedure. if the procedure expects a field and is
    //not present in the incoming SinkRecord it would use null
    //No table evolution is supported yet

    val arguments: Array[String] = PrepareProcedureFieldsFn(procAndFields.fields, fieldsAndValuesMap).toArray
    logger.info(s"Calling procedure:${procAndFields.procName} with parameters:${procAndFields.fields.mkString(",")} with arguments:${arguments.mkString(",")}")

    client.callProcedure(procAndFields.procName, arguments: _*)
  }

  override def close(): Unit = client.close()

  private case class ProcAndFields(procName: String, fields: Seq[String])

} 
Example 26
Source File: CreateSqlFn.scala    From stream-reactor   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.voltdb.writers

import com.typesafe.scalalogging.StrictLogging

object CreateSqlFn extends StrictLogging {
  def apply(targetTable: String, isUpsert: Boolean, columns: Seq[String]): String = {
    val sql =
      s"""
         |${if (isUpsert) "UPSERT" else "INSERT"} INTO $targetTable (${columns.mkString(",")})
         |VALUES (${columns.map(_ => "?").mkString(",")})
    """.stripMargin
    logger.debug(sql)
    sql
  }
} 
Example 27
Source File: HiveSourceConnector.scala    From stream-reactor   with Apache License 2.0
package com.landoop.streamreactor.connect.hive.source

import java.util

import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.landoop.streamreactor.connect.hive.sink.config.HiveSinkConfigDef
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigDef
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.source.SourceConnector

import scala.collection.JavaConverters._

class HiveSourceConnector extends SourceConnector with StrictLogging {

  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)
  private var props: util.Map[String, String] = _

  override def version(): String = manifest.version()
  override def taskClass(): Class[_ <: Task] = classOf[HiveSourceTask]
  override def config(): ConfigDef = HiveSinkConfigDef.config

  override def start(props: util.Map[String, String]): Unit = {
    logger.info(s"Creating hive sink connector")
    this.props = props
  }

  override def stop(): Unit = ()

  override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] = {
    logger.info(s"Creating $maxTasks tasks config")
    List.fill(maxTasks)(props).asJava
  }
} 
Example 28
Source File: SendgridEmailService.scala    From scala-clippy   with Apache License 2.0
package util.email

import com.sendgrid.SendGrid
import com.typesafe.scalalogging.StrictLogging

import scala.concurrent.Future
import scala.util.Properties

class SendgridEmailService(sendgridUsername: String, sendgridPassword: String, emailFrom: String)
    extends EmailService
    with StrictLogging {

  private lazy val sendgrid = new SendGrid(sendgridUsername, sendgridPassword)

  override def send(to: String, subject: String, body: String) = {
    val email = new SendGrid.Email()
    email.addTo(to)
    email.setFrom(emailFrom)
    email.setSubject(subject)
    email.setText(body)

    val response = sendgrid.send(email)
    if (response.getStatus) {
      logger.info(s"Email to $to sent")
    } else {
      logger.error(
        s"Email to $to, subject: $subject, body: $body, not sent: " +
          s"${response.getCode}/${response.getMessage}"
      )
    }

    Future.successful(())
  }
}

object SendgridEmailService extends StrictLogging {
  def createFromEnv(emailFrom: String): Option[SendgridEmailService] =
    for {
      u <- Properties.envOrNone("SENDGRID_USERNAME")
      p <- Properties.envOrNone("SENDGRID_PASSWORD")
    } yield {
      logger.info("Using SendGrid email service")
      new SendgridEmailService(u, p, emailFrom)
    }
} 
Example 29
Source File: SchemaRegistry.scala    From kafka-connect-common   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.schemas

import com.typesafe.scalalogging.StrictLogging
import io.confluent.kafka.schemaregistry.client.rest.RestService

import scala.util.{Failure, Success, Try}
import scala.collection.JavaConverters._


  def getSubjects(url: String) : List[String] = {
    val registry = new RestService(url)
    val schemas: List[String] = Try(registry.getAllSubjects.asScala.toList) match {
      case Success(s) => s
      case Failure(f) => {
        logger.warn("Unable to connect to the Schema registry. An attempt will be made to create the table" +
          " on receipt of the first records.")
        List.empty[String]
      }
    }

    schemas.foreach(s=>logger.info(s"Found schemas for $s"))
    schemas
  }
} 
Example 30
Source File: ErrorPolicy.scala    From kafka-connect-common   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.errors

import java.util.Date

import com.datamountaineer.streamreactor.connect.errors.ErrorPolicyEnum.ErrorPolicyEnum
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.errors.RetriableException


object ErrorPolicyEnum extends Enumeration {
  type ErrorPolicyEnum = Value
  val NOOP, THROW, RETRY = Value
}

case class ErrorTracker(retries: Int, maxRetries: Int, lastErrorMessage: String, lastErrorTimestamp: Date, policy: ErrorPolicy)

trait ErrorPolicy extends StrictLogging {
  def handle(error: Throwable, sink: Boolean = true, retryCount: Int = 0)
}

object ErrorPolicy extends StrictLogging {
  def apply(policy: ErrorPolicyEnum): ErrorPolicy = {
    policy match {
      case ErrorPolicyEnum.NOOP => NoopErrorPolicy()
      case ErrorPolicyEnum.THROW => ThrowErrorPolicy()
      case ErrorPolicyEnum.RETRY => RetryErrorPolicy()
    }
  }
}

case class NoopErrorPolicy() extends ErrorPolicy {
  override def handle(error: Throwable, sink: Boolean = true, retryCount: Int = 0){
    logger.warn(s"Error policy NOOP: ${error.getMessage}. Processing continuing.")
  }
}

case class ThrowErrorPolicy() extends ErrorPolicy {
  override def handle(error: Throwable, sink: Boolean = true, retryCount: Int = 0){
    throw new RuntimeException(error)
  }
}

case class RetryErrorPolicy() extends ErrorPolicy {

  override def handle(error: Throwable, sink: Boolean = true, retryCount: Int) = {
    if (retryCount == 0) {
      throw new RuntimeException(error)
    }
    else {
      logger.warn(s"Error policy set to RETRY. Remaining attempts $retryCount")
      throw new RetriableException(error)
    }
  }
} 
Example 31
Source File: ErrorHandler.scala    From kafka-connect-common   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.errors

import java.text.SimpleDateFormat
import java.util.Date

import com.typesafe.scalalogging.StrictLogging

import scala.util.{Failure, Success, Try}


trait ErrorHandler extends StrictLogging {
  var errorTracker: Option[ErrorTracker] = None
  private val dateFormatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS'Z'")

  def initialize(maxRetries: Int, errorPolicy: ErrorPolicy): Unit = {
    errorTracker = Some(ErrorTracker(maxRetries, maxRetries, "", new Date(), errorPolicy))
  }

  def getErrorTrackerRetries() : Int = {
    errorTracker.get.retries
  }

  def errored() : Boolean = {
    errorTracker.get.retries != errorTracker.get.maxRetries
  }

  def handleTry[A](t : Try[A]) : Option[A] = {
    require(errorTracker.isDefined, "ErrorTracker is not set. Call initialize() first.")
    t
    match {
      case Success(s) => {
        //success, check if we had previous errors.
        if (errorTracker.get.retries != errorTracker.get.maxRetries) {
          logger.info(s"Recovered from error ${errorTracker.get.lastErrorMessage} at " +
            s"${dateFormatter.format(errorTracker.get.lastErrorTimestamp)}")
        }
        //cleared error
        resetErrorTracker()
        Some(s)
      }
      case Failure(f) =>
        //decrement the retry count
        logger.error(s"Encountered error ${f.getMessage}", f)
        this.errorTracker = Some(decrementErrorTracker(errorTracker.get, f.getMessage))
        handleError(f, errorTracker.get.retries, errorTracker.get.policy)
        None
    }
  }

  def resetErrorTracker() = {
    errorTracker = Some(ErrorTracker(errorTracker.get.maxRetries, errorTracker.get.maxRetries, "", new Date(),
      errorTracker.get.policy))
  }

  private def decrementErrorTracker(errorTracker: ErrorTracker, msg: String): ErrorTracker = {
    if (errorTracker.maxRetries == -1) {
      ErrorTracker(errorTracker.retries, errorTracker.maxRetries, msg, new Date(), errorTracker.policy)
    } else {
      ErrorTracker(errorTracker.retries - 1, errorTracker.maxRetries, msg, new Date(), errorTracker.policy)
    }
  }

  private def handleError(f: Throwable, retries: Int, policy: ErrorPolicy): Unit = {
    policy.handle(f, true, retries)
  }
} 
Example 32
Source File: Helpers.scala    From kafka-connect-common   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.config

import com.datamountaineer.kcql.Kcql
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigException


  def tableTopicParser(input: String) : Map[String, String] = {
    input.split(",")
      .toList
      .map(c => c.split(":"))
      .map(a => {if (a.length == 1) (a(0), a(0)) else (a(0), a(1)) }).toMap
  }


  def checkInputTopics(kcqlConstant: String, props: Map[String, String]): Boolean = {
    val topics = props("topics").split(",").map(t => t.trim).toSet
    val raw = props(kcqlConstant)
    if (raw.isEmpty) {
      throw new ConfigException(s"Missing $kcqlConstant")
    }
    val kcql = raw.split(";").map(r => Kcql.parse(r)).toSet
    val sources = kcql.map(k => k.getSource)
    val res = topics.subsetOf(sources)

    if (!res) {
      val missing = topics.diff(sources)
      throw new ConfigException(s"Mandatory `topics` configuration contains topics not set in $kcqlConstant: ${missing}, kcql contains $sources")
    }

    val res1 = sources.subsetOf(topics)

    if (!res1) {
      val missing = sources.diff(topics)
      throw new ConfigException(s"$kcqlConstant configuration contains topics not set in the mandatory `topics` configuration: ${missing}, kcql contains $sources")
    }

    true
  }
} 
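
For example (a minimal sketch, assuming the Helpers object above), the topic-to-table parser falls back to the topic name when no table is given:

val mappings = Helpers.tableTopicParser("topicA:tableA,topicB")
// Map("topicA" -> "tableA", "topicB" -> "topicB")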
Example 33
Source File: ExponentialBackOffHandler.scala    From kafka-connect-common   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.source

import java.time.Duration

import com.typesafe.scalalogging.StrictLogging


class ExponentialBackOffHandler(name: String, step: Duration, cap: Duration) extends StrictLogging  {
  private var backoff = new ExponentialBackOff(step, cap)

  def ready = backoff.passed

  def failure = {
    backoff = backoff.nextFailure
    logger.info(s"$name: Backing off. Next poll will be around ${backoff.endTime}")
  }

  def success = {
    backoff = backoff.nextSuccess
    logger.info(s"$name: Next poll will be around ${backoff.endTime}")
  }

  def update(status: Boolean): Unit = {
    if (status) {
      success
    } else {
      failure
    }
  }

  def remaining = backoff.remaining
} 
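
A minimal polling-loop sketch, assuming the handler above and its ExponentialBackOff dependency are available (poll is a stand-in for a real source poll):

import java.time.Duration

val backoffHandler = new ExponentialBackOffHandler("demo-source", Duration.ofSeconds(1), Duration.ofMinutes(1))

def poll(): Boolean = scala.util.Random.nextBoolean() // stand-in for a real poll

// Only poll once the backoff window has passed, then report the outcome so the
// next window can grow (failure) or reset (success).
if (backoffHandler.ready) {
  backoffHandler.update(poll())
}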
Example 34
Source File: FutureAwaitWithFailFastFn.scala    From kafka-connect-common   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.concurrent

import java.util.concurrent.{ExecutorService, TimeUnit}

import com.typesafe.scalalogging.StrictLogging

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent.{Await, Future, Promise}
import scala.util.Failure

object FutureAwaitWithFailFastFn extends StrictLogging {

  def apply(executorService: ExecutorService, futures: Seq[Future[Unit]], duration: Duration): Unit = {
    //make sure we ask the executor to shutdown to ensure the process exits
    executorService.shutdown()

    val promise = Promise[Boolean]()

    //stop on the first failure
    futures.foreach { f =>
      f.failed.foreach { case t =>
        if (promise.tryFailure(t)) {
          executorService.shutdownNow()
        }
      }
    }

    val fut = Future.sequence(futures)
    fut.foreach { case t =>
      if (promise.trySuccess(true)) {
        val failed = executorService.shutdownNow()
        if (failed.size() > 0) {
          logger.error(s"${failed.size()} task have failed.")
        }
      }
    }

    Await.ready(promise.future, duration).value match {
      case Some(Failure(t)) =>
        executorService.awaitTermination(1, TimeUnit.MINUTES)
        //throw the underlying error
        throw t

      case _ =>
        executorService.awaitTermination(1, TimeUnit.MINUTES)
    }
  }

  def apply[T](executorService: ExecutorService, futures: Seq[Future[T]], duration: Duration = 1.hours): Seq[T] = {
    //make sure we ask the executor to shutdown to ensure the process exits
    executorService.shutdown()

    val promise = Promise[Boolean]()

    //stop on the first failure
    futures.foreach { f =>
      f.failed.foreach { case t =>
        if (promise.tryFailure(t)) {
          executorService.shutdownNow()
        }
      }
    }

    val fut = Future.sequence(futures)
    fut.foreach { case t =>
      if (promise.trySuccess(true)) {
        val failed = executorService.shutdownNow()
        if (failed.size() > 0) {
          logger.error(s"${failed.size()} task have failed.")
        }
      }
    }

    Await.ready(promise.future, duration).value match {
      case Some(Failure(t)) =>
        executorService.awaitTermination(1, TimeUnit.MINUTES)
        //throw the underlying error
        throw t

      case _ =>
        executorService.awaitTermination(1, TimeUnit.MINUTES)
        //return the result from each of the futures
        Await.result(Future.sequence(futures), 1.minute)
    }
  }
} 
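
A minimal sketch of the fail-fast await, assuming the object above is on the classpath: if any future fails, the executor is shut down immediately instead of waiting for the remaining work.

import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._

val executor = Executors.newFixedThreadPool(4)
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)

val work = (1 to 4).map(i => Future(i * 2))
val results = FutureAwaitWithFailFastFn(executor, work, 30.seconds)
// results == Seq(2, 4, 6, 8) when every future succeeds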
Example 35
Source File: ExecutorIdExtenderPluginTest.scala    From marathon-example-plugins   with Apache License 2.0 5 votes vote down vote up
package mesosphere.marathon.example.plugin.executorid

import com.typesafe.scalalogging.StrictLogging
import org.apache.mesos.Protos.Environment.Variable
import org.apache.mesos.Protos._
import org.scalatest.{GivenWhenThen, Matchers, WordSpec}

class ExecutorIdExtenderPluginTest extends WordSpec with Matchers with GivenWhenThen with StrictLogging {

  "Given an MARATHON_EXECUTOR_ID label an executorID should be injected" in {
    val f = new Fixture

    Given("a TaskInfo with a MARATHON_EXECUTOR_ID label")
    val taskInfo = TaskInfo.newBuilder.
      setExecutor(ExecutorInfo.newBuilder.
          setCommand(CommandInfo.newBuilder.
            setEnvironment(Environment.newBuilder.addVariables(
                Variable.newBuilder.setName("foo").setValue("bar")
            )
          )).
        setExecutorId(ExecutorID.newBuilder.setValue("task.12345"))
      ).
      setLabels(Labels.newBuilder.addLabels(Label.newBuilder.
        setKey(f.plugin.ExecutorIdLabel)
          .setValue("customer-executor-id")
      ))

    When("handled by the plugin")
    f.plugin.taskInfo(null, taskInfo)

    Then("ExecutorInfo.ExecutorId should be changed")
    taskInfo.getExecutor.getExecutorId.getValue shouldBe "customer-executor-id"

    And("Environment variables should be removed")
    taskInfo.getExecutor.getCommand.getEnvironment.getVariablesCount shouldBe 0
  }

  "Given no MARATHON_EXECUTOR_ID label an executorID should be untouched" in {
    val f = new Fixture

    Given("a TaskInfo with a MARATHON_EXECUTOR_ID label")
    val taskInfo = TaskInfo.newBuilder.
      setExecutor(ExecutorInfo.newBuilder.
        setCommand(CommandInfo.newBuilder.
          setEnvironment(Environment.newBuilder.addVariables(
            Variable.newBuilder.setName("foo").setValue("bar")
          )
          )).
        setExecutorId(ExecutorID.newBuilder.setValue("task.12345"))
      ).
      setLabels(Labels.newBuilder.addLabels(Label.newBuilder.
        setKey("baz")
        .setValue("wof")
      ))

    When("handled by the plugin")
    f.plugin.taskInfo(null, taskInfo)

    Then("ExecutorInfo.ExecutorId should stay the same")
    taskInfo.getExecutor.getExecutorId.getValue shouldBe "task.12345"

    And("environment variables should be kept")
    taskInfo.getExecutor.getCommand.getEnvironment.getVariablesCount shouldBe 1
  }

  class Fixture {
    val plugin = new ExecutorIdExtenderPlugin()
  }
} 
Example 36
Source File: ExecutorIdExtenderPlugin.scala    From marathon-example-plugins   with Apache License 2.0 5 votes vote down vote up
package mesosphere.marathon.example.plugin.executorid

import com.typesafe.scalalogging.StrictLogging
import mesosphere.marathon.plugin.{ApplicationSpec, PodSpec}
import mesosphere.marathon.plugin.plugin.PluginConfiguration
import mesosphere.marathon.plugin.task.RunSpecTaskProcessor
import org.apache.mesos.Protos._
import play.api.libs.json.JsObject

import scala.collection.JavaConverters._

class ExecutorIdExtenderPlugin extends RunSpecTaskProcessor with PluginConfiguration with StrictLogging {

  val ExecutorIdLabel = "MARATHON_EXECUTOR_ID"

  override def taskInfo(appSpec: ApplicationSpec, builder: TaskInfo.Builder): Unit = {
    // If custom executor is used
    if (builder.hasExecutor && builder.getExecutor.hasCommand) {
      val labels = builder.getLabels.getLabelsList.asScala

      // ... and there is MARATHON_EXECUTOR_ID label set
      labels.find(_.getKey == ExecutorIdLabel).foreach {label =>
        // Set the executorID from the MARATHON_EXECUTOR_ID label
        val executorId = label.getValue
        val executorBuilder = builder.getExecutor.toBuilder
        executorBuilder.setExecutorId(ExecutorID.newBuilder.setValue(executorId))

        // An executor id for the executor that launches this application. Note that all applications sharing the
        // same executor id will share the same executor instance, which saves resources. The downside is that all
        // apps started with the same executor id must have identical `TaskInfo.ExecutorInfo`. Among other things,
        // that means their environment variables must be identical. Since Marathon automatically generates per-task
        // environment variables like `MARATHON_APP_VERSION`, `MESOS_TASK_ID` or `PORTx`, this would not work.
        // For this reason we simply remove all the environment variables. It is possible to be more selective and
        // remove only those that change from task to task, but that is too much hassle for this simple plugin.
        val commandBuilder = executorBuilder.getCommand.toBuilder
        commandBuilder.clearEnvironment()
        executorBuilder.setCommand(commandBuilder)

        builder.setExecutor(executorBuilder)
      }
    }
  }

  override def taskGroup(podSpec: PodSpec, executor: ExecutorInfo.Builder, taskGroup: TaskGroupInfo.Builder): Unit = {}

  override def initialize(marathonInfo: Map[String, Any], configuration: JsObject): Unit = {
    logger.info(s"ExecutorIdExtenderPlugin successfully initialized")
  }
} 
Example 37
Source File: ExampleIdentity.scala    From marathon-example-plugins   with Apache License 2.0 5 votes vote down vote up
package mesosphere.marathon.example.plugin.auth

import com.typesafe.scalalogging.StrictLogging
import mesosphere.marathon.plugin.auth._
import play.api.libs.functional.syntax._
import play.api.libs.json._


case class ExampleIdentity(username: String, password: String, permissions: Seq[Permission]) extends Identity with StrictLogging {

  def isAllowed[R](action: AuthorizedAction[R], resource: R): Boolean = {
    val permit = permissions.find { permission =>
      permission.eligible(action) && permission.isAllowed(resource)
    }
    permit match {
      case Some(p) => logger.info(s"Found permit: $p")
      case None    => logger.error(s"$username is not allowed for action $action on resource $resource")
    }
    permit.isDefined
  }
}

object ExampleIdentity {
  implicit val identityRead: Reads[ExampleIdentity] = (
    (__ \ "user").read[String] ~
    (__ \ "password").read[String] ~
    (__ \ "permissions").read[Seq[PathPermission]]
  ) ((name, pass, permissions) => ExampleIdentity(name, pass, permissions))
} 
Example 38
Source File: AkkaHttpServerTests.scala    From tapir   with Apache License 2.0 5 votes vote down vote up
package sttp.tapir.server.akkahttp

import cats.implicits._
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.server.Route
import akka.http.scaladsl.server.Directives
import akka.http.scaladsl.server.Directives._
import cats.data.NonEmptyList
import cats.effect.{IO, Resource}
import sttp.client._
import com.typesafe.scalalogging.StrictLogging
import sttp.tapir.{Endpoint, endpoint, stringBody}
import sttp.tapir.server.tests.ServerTests
import sttp.tapir._
import sttp.tapir.server.{DecodeFailureHandler, ServerDefaults, ServerEndpoint}
import sttp.tapir.tests.{Port, PortCounter}

import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.reflect.ClassTag

class AkkaHttpServerTests extends ServerTests[Future, AkkaStream, Route] with StrictLogging {
  private implicit var actorSystem: ActorSystem = _

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    actorSystem = ActorSystem()
  }

  override protected def afterAll(): Unit = {
    Await.result(actorSystem.terminate(), 5.seconds)
    super.afterAll()
  }

  override def route[I, E, O](
      e: ServerEndpoint[I, E, O, AkkaStream, Future],
      decodeFailureHandler: Option[DecodeFailureHandler] = None
  ): Route = {
    implicit val serverOptions: AkkaHttpServerOptions = AkkaHttpServerOptions.default.copy(
      decodeFailureHandler = decodeFailureHandler.getOrElse(ServerDefaults.decodeFailureHandler)
    )
    e.toRoute
  }

  override def routeRecoverErrors[I, E <: Throwable, O](e: Endpoint[I, E, O, AkkaStream], fn: I => Future[O])(implicit
      eClassTag: ClassTag[E]
  ): Route = {
    e.toRouteRecoverErrors(fn)
  }

  override def server(routes: NonEmptyList[Route], port: Port): Resource[IO, Unit] = {
    val bind = IO.fromFuture(IO(Http().bindAndHandle(routes.toList.reduce(_ ~ _), "localhost", port)))
    Resource.make(bind)(binding => IO.fromFuture(IO(binding.unbind())).void).void
  }

  override def pureResult[T](t: T): Future[T] = Future.successful(t)
  override def suspendResult[T](t: => T): Future[T] = {
    import scala.concurrent.ExecutionContext.Implicits.global
    Future { t }
  }

  override lazy val portCounter: PortCounter = new PortCounter(57000)

  if (testNameFilter.isEmpty) {
    test("endpoint nested in a path directive") {
      val e = endpoint.get.in("test" and "directive").out(stringBody).serverLogic(_ => pureResult("ok".asRight[Unit]))
      val port = portCounter.next()
      val route = Directives.pathPrefix("api")(e.toRoute)
      server(NonEmptyList.of(route), port).use { _ =>
        basicRequest.get(uri"http://localhost:$port/api/test/directive").send().map(_.body shouldBe Right("ok"))
      }.unsafeRunSync
    }
  }
} 
Example 39
Source File: SqlDatabase.scala    From scala-clippy   with Apache License 2.0 5 votes vote down vote up
package util

import java.net.URI

import com.typesafe.config.ConfigValueFactory._
import com.typesafe.config.{Config, ConfigFactory}
import com.typesafe.scalalogging.StrictLogging
import org.flywaydb.core.Flyway
import slick.driver.JdbcProfile
import slick.jdbc.JdbcBackend._

case class SqlDatabase(
    db: slick.jdbc.JdbcBackend#Database,
    driver: JdbcProfile,
    connectionString: JdbcConnectionString
) {
  def updateSchema(): Unit = {
    val flyway = new Flyway()
    flyway.setDataSource(connectionString.url, connectionString.username, connectionString.password)
    flyway.migrate()
  }

  def close(): Unit = {
    db.close()
  }
}

case class JdbcConnectionString(url: String, username: String = "", password: String = "")

object SqlDatabase extends StrictLogging {

  def create(config: DatabaseConfig): SqlDatabase = {
    val envDatabaseUrl = System.getenv("DATABASE_URL")

    if (config.dbPostgresServerName.length > 0)
      createPostgresFromConfig(config)
    else if (envDatabaseUrl != null)
      createPostgresFromEnv(envDatabaseUrl)
    else
      createEmbedded(config)
  }

  def createEmbedded(connectionString: String): SqlDatabase = {
    val db = Database.forURL(connectionString)
    SqlDatabase(db, slick.driver.H2Driver, JdbcConnectionString(connectionString))
  }

  private def createPostgresFromEnv(envDatabaseUrl: String) = {
    import DatabaseConfig._
    
    val dbUri    = new URI(envDatabaseUrl)
    val username = dbUri.getUserInfo.split(":")(0)
    val password = dbUri.getUserInfo.split(":")(1)
    val intermediaryConfig = new DatabaseConfig {
      override def rootConfig: Config =
        ConfigFactory
          .empty()
          .withValue(PostgresDSClass, fromAnyRef("org.postgresql.ds.PGSimpleDataSource"))
          .withValue(PostgresServerNameKey, fromAnyRef(dbUri.getHost))
          .withValue(PostgresPortKey, fromAnyRef(dbUri.getPort))
          .withValue(PostgresDbNameKey, fromAnyRef(dbUri.getPath.tail))
          .withValue(PostgresUsernameKey, fromAnyRef(username))
          .withValue(PostgresPasswordKey, fromAnyRef(password))
          .withFallback(ConfigFactory.load())
    }
    createPostgresFromConfig(intermediaryConfig)
  }

  private def postgresUrl(host: String, port: String, dbName: String) =
    s"jdbc:postgresql://$host:$port/$dbName"

  private def postgresConnectionString(config: DatabaseConfig) = {
    val host     = config.dbPostgresServerName
    val port     = config.dbPostgresPort
    val dbName   = config.dbPostgresDbName
    val username = config.dbPostgresUsername
    val password = config.dbPostgresPassword
    JdbcConnectionString(postgresUrl(host, port, dbName), username, password)
  }

  private def createPostgresFromConfig(config: DatabaseConfig) = {
    val db = Database.forConfig("db.postgres", config.rootConfig)
    SqlDatabase(db, slick.driver.PostgresDriver, postgresConnectionString(config))
  }

  private def createEmbedded(config: DatabaseConfig): SqlDatabase = {
    val db = Database.forConfig("db.h2")
    SqlDatabase(db, slick.driver.H2Driver, JdbcConnectionString(embeddedConnectionStringFromConfig(config)))
  }

  private def embeddedConnectionStringFromConfig(config: DatabaseConfig): String = {
    val url      = config.dbH2Url
    val fullPath = url.split(":")(3)
    logger.info(s"Using an embedded database, with data files located at: $fullPath")
    url
  }
} 
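
A minimal sketch of the embedded path (the in-memory URL is illustrative and assumes H2 and the Flyway migrations are on the classpath):

val db = SqlDatabase.createEmbedded("jdbc:h2:mem:clippy;DB_CLOSE_DELAY=-1")
db.updateSchema() // runs the Flyway migrations against the embedded database
// ... run Slick queries through db.db ...
db.close()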
Example 40
Source File: Util.scala    From iep-apps   with Apache License 2.0 5 votes vote down vote up
package com.netflix.atlas.slotting

import java.nio.ByteBuffer
import java.time.Duration
import java.util.concurrent.ScheduledFuture

import com.netflix.iep.NetflixEnvironment
import com.netflix.spectator.api.Registry
import com.netflix.spectator.impl.Scheduler
import com.typesafe.config.Config
import com.typesafe.scalalogging.StrictLogging

object Util extends StrictLogging {

  def getLongOrDefault(config: Config, basePath: String): Long = {
    val env = NetflixEnvironment.accountEnv()
    val region = NetflixEnvironment.region()

    if (config.hasPath(s"$basePath.$env.$region"))
      config.getLong(s"$basePath.$env.$region")
    else
      config.getLong(s"$basePath.default")
  }

  def compress(s: String): ByteBuffer = {
    ByteBuffer.wrap(Gzip.compressString(s))
  }

  def decompress(buf: ByteBuffer): String = {
    Gzip.decompressString(toByteArray(buf))
  }

  def toByteArray(buf: ByteBuffer): Array[Byte] = {
    val bytes = new Array[Byte](buf.remaining)
    buf.get(bytes, 0, bytes.length)
    buf.clear()
    bytes
  }

  def startScheduler(
    registry: Registry,
    name: String,
    interval: Duration,
    fn: () => Unit
  ): ScheduledFuture[_] = {
    val scheduler = new Scheduler(registry, name, 2)
    val options = new Scheduler.Options()
      .withFrequency(Scheduler.Policy.FIXED_RATE_SKIP_IF_LONG, interval)
    scheduler.schedule(options, () => fn())
  }

} 
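
For example (a sketch with illustrative property names; the actual env and region come from NetflixEnvironment at runtime), the lookup prefers an env/region-specific value and falls back to the default:

import com.typesafe.config.ConfigFactory

val cfg = ConfigFactory.parseString(
  """
    |slotting.cache-ttl.default = 30
    |slotting.cache-ttl.test.us-east-1 = 60
  """.stripMargin
)

// Would return 60 when running in env "test" / region "us-east-1", otherwise 30.
val ttl = Util.getLongOrDefault(cfg, "slotting.cache-ttl")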
Example 41
Source File: DummyEmailService.scala    From scala-clippy   with Apache License 2.0 5 votes vote down vote up
package util.email

import com.typesafe.scalalogging.StrictLogging

import scala.collection.mutable.ListBuffer
import scala.concurrent.Future

class DummyEmailService extends EmailService with StrictLogging {
  private val sentEmails: ListBuffer[(String, String, String)] = ListBuffer()

  logger.info("Using dummy email service")

  def reset(): Unit = {
    sentEmails.clear()
  }

  override def send(to: String, subject: String, body: String) = {
    this.synchronized {
      sentEmails.+=((to, subject, body))
    }

    logger.info(s"Would send email to $to, with subject: $subject, body: $body")
    Future.successful(())
  }

  def wasEmailSent(to: String, subject: String): Boolean =
    sentEmails.exists(email => email._1.contains(to) && email._2 == subject)
} 
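
A minimal test-style sketch, assuming the EmailService trait from the same project: sends are only recorded in memory, so assertions never touch a real mail server.

val emails = new DummyEmailService
emails.send("user@example.com", "Welcome", "Hello!")
assert(emails.wasEmailSent("user@example.com", "Welcome"))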
Example 42
Source File: CwForwardingConfigSuite.scala    From iep-apps   with Apache License 2.0 5 votes vote down vote up
package com.netflix.iep.lwc.fwd.admin

import akka.actor.ActorSystem
import com.netflix.atlas.eval.stream.Evaluator
import com.netflix.spectator.api.NoopRegistry
import com.typesafe.config.ConfigFactory
import com.typesafe.scalalogging.StrictLogging
import org.scalatest.funsuite.AnyFunSuite

class CwForwardingConfigSuite extends AnyFunSuite with CwForwardingTestConfig with StrictLogging {

  private val config = ConfigFactory.load()
  private val system = ActorSystem()

  val validations = new CwExprValidations(
    new ExprInterpreter(config),
    new Evaluator(config, new NoopRegistry(), system)
  )

  test("Skip the given checks") {
    val config = makeConfig(checksToSkip = List("DefaultDimension"))
    assert(config.shouldSkip("DefaultDimension"))
  }

  test("Do checks that are not flagged to skip") {
    val config = makeConfig(checksToSkip = List("DefaultDimension"))
    assert(!config.shouldSkip("SingleExpression"))
  }

} 
Example 43
Source File: Api.scala    From iep-apps   with Apache License 2.0 5 votes vote down vote up
package com.netflix.iep.lwc.fwd.admin

import akka.actor.ActorSystem
import akka.http.scaladsl.model._
import akka.http.scaladsl.server.Directives.entity
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.Route
import akka.http.scaladsl.unmarshalling.Unmarshaller._
import com.fasterxml.jackson.databind.JsonNode
import com.netflix.atlas.akka.CustomDirectives._
import com.netflix.atlas.akka.WebApi
import com.netflix.atlas.json.Json
import com.netflix.iep.lwc.fwd.cw.ExpressionId
import com.netflix.iep.lwc.fwd.cw.Report
import com.netflix.spectator.api.Registry
import com.typesafe.scalalogging.StrictLogging

import scala.concurrent.Future

class Api(
  registry: Registry,
  schemaValidation: SchemaValidation,
  cwExprValidations: CwExprValidations,
  markerService: MarkerService,
  purger: Purger,
  exprDetailsDao: ExpressionDetailsDao,
  system: ActorSystem
) extends WebApi
    with StrictLogging {

  private implicit val configUnmarshaller =
    byteArrayUnmarshaller.map(Json.decode[JsonNode](_))

  private implicit val blockingDispatcher = system.dispatchers.lookup("blocking-dispatcher")

  def routes: Route = {

    endpointPath("api" / "v1" / "cw" / "check", Remaining) { key =>
      post {
        entity(as[JsonNode]) { json =>
          complete {
            schemaValidation.validate(key, json)
            cwExprValidations.validate(key, json)

            HttpResponse(StatusCodes.OK)
          }
        }
      }
    } ~
    endpointPath("api" / "v1" / "cw" / "report") {
      post {
        entity(as[JsonNode]) { json =>
          complete {
            Json
              .decode[List[Report]](json)
              .foreach { report =>
                val enqueued = markerService.queue.offer(report)
                if (!enqueued) {
                  logger.warn(s"Unable to queue report $report")
                }
              }
            HttpResponse(StatusCodes.OK)
          }
        }
      }
    } ~
    endpointPath("api" / "v1" / "cw" / "expr" / "purgeEligible") {
      get {
        parameter("events".as(CsvSeq[String])) { events =>
          complete {
            Future {
              val body = Json.encode(
                exprDetailsDao.queryPurgeEligible(
                  System.currentTimeMillis(),
                  events.toList
                )
              )

              HttpResponse(
                StatusCodes.OK,
                entity = HttpEntity(MediaTypes.`application/json`, body)
              )
            }
          }
        }
      }
    } ~
    endpointPath("api" / "v1" / "cw" / "expr" / "purge") {
      delete {
        entity(as[JsonNode]) { json =>
          complete {
            purger.purge(Json.decode[List[ExpressionId]](json))
          }
        }
      }
    }
  }

} 
Example 44
Source File: SchemaValidation.scala    From iep-apps   with Apache License 2.0 5 votes vote down vote up
package com.netflix.iep.lwc.fwd.admin

import com.fasterxml.jackson.databind.JsonNode
import com.github.fge.jsonschema.main.JsonSchema
import com.github.fge.jsonschema.main.JsonSchemaFactory
import com.netflix.atlas.json.Json
import com.typesafe.scalalogging.StrictLogging

import scala.io.Source
import scala.jdk.CollectionConverters._

class SchemaValidation extends StrictLogging {

  val schema: JsonSchema = {
    val reader = Source.fromResource("cw-fwding-cfg-schema.json").reader()
    try {
      JsonSchemaFactory
        .byDefault()
        .getJsonSchema(Json.decode[SchemaCfg](reader).schema)
    } finally {
      reader.close()
    }
  }

  def validate(key: String, json: JsonNode): Unit = {
    val pr = schema.validate(json)
    if (!pr.isSuccess) {
      throw new IllegalArgumentException(
        pr.asScala.map(_.getMessage).mkString("\n")
      )
    }
  }

}

case class SchemaCfg(schema: JsonNode, validationHook: String) 
Example 45
Source File: PropertiesLoader.scala    From iep-apps   with Apache License 2.0 5 votes vote down vote up
package com.netflix.iep.archaius

import akka.actor.Actor
import com.amazonaws.services.dynamodbv2.model.ScanRequest
import com.netflix.atlas.json.Json
import com.typesafe.config.Config
import com.typesafe.scalalogging.StrictLogging

import scala.util.Failure
import scala.util.Success


class PropertiesLoader(config: Config, propContext: PropertiesContext, dynamoService: DynamoService)
    extends Actor
    with StrictLogging {

  private val table = config.getString("netflix.iep.archaius.table")

  import scala.concurrent.duration._
  import scala.concurrent.ExecutionContext.Implicits.global
  context.system.scheduler.schedule(5.seconds, 5.seconds, self, PropertiesLoader.Tick)

  def receive: Receive = {
    case PropertiesLoader.Tick =>
      val future = dynamoService.execute { client =>
        val matches = List.newBuilder[PropertiesApi.Property]
        val request = new ScanRequest().withTableName(table)
        var response = client.scan(request)
        matches ++= process(response.getItems)
        while (response.getLastEvaluatedKey != null) {
          request.setExclusiveStartKey(response.getLastEvaluatedKey)
          response = client.scan(request)
          matches ++= process(response.getItems)
        }
        matches.result()
      }

      future.onComplete {
        case Success(vs) => propContext.update(vs)
        case Failure(t)  => logger.error("failed to refresh properties from dynamodb", t)
      }
  }

  private def process(items: Items): PropList = {
    import scala.jdk.CollectionConverters._
    items.asScala
      .filter(_.containsKey("data"))
      .map(_.get("data").getS)
      .map(s => Json.decode[PropertiesApi.Property](s))
      .toList
  }
}

object PropertiesLoader {
  case object Tick
} 
Example 46
Source File: FileUtil.scala    From iep-apps   with Apache License 2.0 5 votes vote down vote up
package com.netflix.atlas.persistence

import java.io.File
import java.nio.file.Files

import com.netflix.atlas.core.util.Streams
import com.typesafe.scalalogging.StrictLogging

import scala.jdk.StreamConverters._

object FileUtil extends StrictLogging {

  def delete(f: File): Unit = {
    try {
      Files.delete(f.toPath)
      logger.debug(s"deleted file $f")
    } catch {
      case e: Exception => logger.error(s"failed to delete path $f", e)
    }
  }

  def listFiles(f: File): List[File] = {
    try {
      Streams.scope(Files.list(f.toPath)) { dir =>
        dir.toScala(List).map(_.toFile)
      }
    } catch {
      case e: Exception =>
        logger.error(s"failed to list files for: $f", e)
        Nil
    }
  }

  def isTmpFile(f: File): Boolean = {
    f.getName.endsWith(RollingFileWriter.TmpFileSuffix)
  }

} 
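
For example (a sketch with an illustrative directory), list a spool directory and skip files that are still being written:

import java.io.File

val readyFiles = FileUtil.listFiles(new File("/tmp/atlas-data")).filterNot(FileUtil.isTmpFile)
readyFiles.foreach(f => println(s"ready to upload: $f"))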
Example 47
Source File: LocalFilePersistService.scala    From iep-apps   with Apache License 2.0 5 votes vote down vote up
package com.netflix.atlas.persistence

import akka.Done
import akka.NotUsed
import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.Flow
import akka.stream.scaladsl.Keep
import akka.stream.scaladsl.RestartFlow
import akka.stream.scaladsl.Sink
import com.netflix.atlas.akka.StreamOps
import com.netflix.atlas.akka.StreamOps.SourceQueue
import com.netflix.atlas.core.model.Datapoint
import com.netflix.iep.service.AbstractService
import com.netflix.spectator.api.Registry
import com.typesafe.config.Config
import com.typesafe.scalalogging.StrictLogging
import javax.inject.Inject
import javax.inject.Singleton

import scala.concurrent.Await
import scala.concurrent.Future
import scala.concurrent.duration.Duration

@Singleton
class LocalFilePersistService @Inject()(
  val config: Config,
  val registry: Registry,
  // S3CopyService is actually NOT used by this service, it is here just to guarantee that the
  // shutdown callback (stopImpl) of this service is invoked before S3CopyService's
  val s3CopyService: S3CopyService,
  implicit val system: ActorSystem
) extends AbstractService
    with StrictLogging {
  implicit val ec = scala.concurrent.ExecutionContext.global
  implicit val mat = ActorMaterializer()

  private val queueSize = config.getInt("atlas.persistence.queue-size")

  private val fileConfig = config.getConfig("atlas.persistence.local-file")
  private val dataDir = fileConfig.getString("data-dir")
  private val maxRecords = fileConfig.getLong("max-records")
  private val maxDurationMs = fileConfig.getDuration("max-duration").toMillis
  private val maxLateDurationMs = fileConfig.getDuration("max-late-duration").toMillis
  private val rollingConf = RollingConfig(maxRecords, maxDurationMs, maxLateDurationMs)

  require(queueSize > 0)
  require(maxRecords > 0)
  require(maxDurationMs > 0)

  private var queue: SourceQueue[Datapoint] = _
  private var flowComplete: Future[Done] = _

  override def startImpl(): Unit = {
    logger.info("Starting service")
    val (q, f) = StreamOps
      .blockingQueue[Datapoint](registry, "LocalFilePersistService", queueSize)
      .via(getRollingFileFlow)
      .toMat(Sink.ignore)(Keep.both)
      .run
    queue = q
    flowComplete = f
  }

  private def getRollingFileFlow(): Flow[Datapoint, NotUsed, NotUsed] = {
    import scala.concurrent.duration._
    RestartFlow.withBackoff(
      minBackoff = 1.second,
      maxBackoff = 3.seconds,
      randomFactor = 0,
      maxRestarts = -1
    ) { () =>
      Flow.fromGraph(
        new RollingFileFlow(dataDir, rollingConf, registry)
      )
    }
  }

  // This service should stop the Akka flow when application is shutdown gracefully, and let
  // S3CopyService do the cleanup. It should trigger:
  //   1. stop taking more data points (monitor droppedQueueClosed)
  //   2. close current file writer so that last file is ready to copy to s3
  override def stopImpl(): Unit = {
    logger.info("Stopping service")
    queue.complete()
    Await.result(flowComplete, Duration.Inf)
    logger.info("Stopped service")
  }

  def persist(dp: Datapoint): Unit = {
    queue.offer(dp)
  }
} 
Example 48
Source File: S3CopyService.scala    From iep-apps   with Apache License 2.0 5 votes vote down vote up
package com.netflix.atlas.persistence

import java.io.File
import java.nio.file.Files
import java.nio.file.Paths

import akka.NotUsed
import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.KillSwitch
import akka.stream.KillSwitches
import akka.stream.scaladsl.Keep
import akka.stream.scaladsl.Source
import com.netflix.atlas.core.util.Streams
import com.netflix.iep.service.AbstractService
import com.netflix.spectator.api.Registry
import com.typesafe.config.Config
import com.typesafe.scalalogging.StrictLogging
import javax.inject.Inject
import javax.inject.Singleton

import scala.concurrent.duration._

@Singleton
class S3CopyService @Inject()(
  val config: Config,
  val registry: Registry,
  implicit val system: ActorSystem
) extends AbstractService
    with StrictLogging {

  private val dataDir = config.getString("atlas.persistence.local-file.data-dir")

  private implicit val mat = ActorMaterializer()

  private var killSwitch: KillSwitch = _
  private val s3Config = config.getConfig("atlas.persistence.s3")

  private val cleanupTimeoutMs = s3Config.getDuration("cleanup-timeout").toMillis
  private val maxInactiveMs = s3Config.getDuration("max-inactive-duration").toMillis
  private val maxFileDurationMs =
    config.getDuration("atlas.persistence.local-file.max-duration").toMillis

  require(
    maxInactiveMs > maxFileDurationMs,
    "`max-inactive-duration` MUST be longer than `max-duration`, otherwise file may be renamed before normal write competes"
  )

  override def startImpl(): Unit = {
    logger.info("Starting service")
    killSwitch = Source
      .tick(1.second, 5.seconds, NotUsed)
      .viaMat(KillSwitches.single)(Keep.right)
      .flatMapMerge(Int.MaxValue, _ => Source(FileUtil.listFiles(new File(dataDir))))
      .toMat(new S3CopySink(s3Config, registry, system))(Keep.left)
      .run()
  }

  override def stopImpl(): Unit = {
    logger.info("Stopping service")
    waitForCleanup()
    if (killSwitch != null) killSwitch.shutdown()
  }

  private def waitForCleanup(): Unit = {
    logger.info("Waiting for cleanup")
    val start = System.currentTimeMillis
    while (hasMoreFiles) {
      if (System.currentTimeMillis() > start + cleanupTimeoutMs) {
        logger.error("Cleanup timeout")
        return
      }
      Thread.sleep(1000)
    }
    logger.info("Cleanup done")
  }

  private def hasMoreFiles: Boolean = {
    try {
      Streams.scope(Files.list(Paths.get(dataDir))) { dir =>
        dir.anyMatch(f => Files.isRegularFile(f))
      }
    } catch {
      case e: Exception => {
        logger.error(s"Error checking hasMoreFiles in $dataDir", e)
        true // Assume there are more files on error so the check is retried
      }
    }
  }
} 
Example 49
Source File: RollingFileFlow.scala    From iep-apps   with Apache License 2.0 5 votes vote down vote up
package com.netflix.atlas.persistence

import java.nio.file.Files
import java.nio.file.Paths

import akka.NotUsed
import akka.stream.Attributes
import akka.stream.FlowShape
import akka.stream.Inlet
import akka.stream.Outlet
import akka.stream.stage.GraphStage
import akka.stream.stage.GraphStageLogic
import akka.stream.stage.InHandler
import akka.stream.stage.OutHandler
import akka.stream.stage.TimerGraphStageLogic
import com.netflix.atlas.core.model.Datapoint
import com.netflix.spectator.api.Registry
import com.typesafe.scalalogging.StrictLogging

import scala.concurrent.duration._

class RollingFileFlow(
  val dataDir: String,
  val rollingConf: RollingConfig,
  val registry: Registry
) extends GraphStage[FlowShape[Datapoint, NotUsed]]
    with StrictLogging {

  private val in = Inlet[Datapoint]("RollingFileSink.in")
  private val out = Outlet[NotUsed]("RollingFileSink.out")
  override val shape = FlowShape(in, out)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = {

    new TimerGraphStageLogic(shape) with InHandler with OutHandler {

      private var hourlyWriter: HourlyRollingWriter = _

      override def preStart(): Unit = {
        logger.info(s"creating sink directory: $dataDir")
        Files.createDirectories(Paths.get(dataDir))

        hourlyWriter = new HourlyRollingWriter(dataDir, rollingConf, registry)
        hourlyWriter.initialize
        // Trigger a periodic rollover check for when the writer sits idle for a long time: typically the
        // file writer is idle once the hour has ended but is still waiting for late events.
        schedulePeriodically(None, 5.seconds)
      }

      override def onPush(): Unit = {
        hourlyWriter.write(grab(in))
        pull(in)
      }

      override protected def onTimer(timerKey: Any): Unit = {
        hourlyWriter.write(RollingFileWriter.RolloverCheckDatapoint)
      }

      override def onUpstreamFinish(): Unit = {
        hourlyWriter.close()
        completeStage()
      }

      override def onUpstreamFailure(ex: Throwable): Unit = {
        hourlyWriter.close()
        failStage(ex)
      }

      setHandlers(in, out, this)

      override def onPull(): Unit = {
        // Nothing to emit
        pull(in)
      }
    }
  }
} 
Example 50
Source File: Main.scala    From iep-apps   with Apache License 2.0 5 votes vote down vote up
package com.netflix.atlas.persistence

import com.google.inject.AbstractModule
import com.google.inject.Module
import com.netflix.iep.guice.GuiceHelper
import com.netflix.iep.service.ServiceManager
import com.netflix.spectator.api.NoopRegistry
import com.netflix.spectator.api.Registry
import com.typesafe.config.Config
import com.typesafe.config.ConfigFactory
import com.typesafe.scalalogging.StrictLogging

object Main extends StrictLogging {

  private def getBaseModules: java.util.List[Module] = {
    val modules = GuiceHelper.getModulesUsingServiceLoader
    if (!sys.env.contains("NETFLIX_ENVIRONMENT")) {
      // If we are running in a local environment provide simple version of the config
      // binding. These bindings are normally provided by the final package
      // config for the app in the production setup.
      modules.add(new AbstractModule {
        override def configure(): Unit = {
          bind(classOf[Config]).toInstance(ConfigFactory.load())
          bind(classOf[Registry]).toInstance(new NoopRegistry)
        }
      })
    }
    modules
  }

  def main(args: Array[String]): Unit = {
    try {
      val modules = getBaseModules
      modules.add(new AppModule)
      val guice = new GuiceHelper
      guice.start(modules)
      guice.getInjector.getInstance(classOf[ServiceManager])
      guice.addShutdownHook()
    } catch {
      // Send exceptions to main log file instead of wherever STDERR is sent for the process
      case t: Throwable => logger.error("fatal error on startup", t)
    }
  }
} 
Example 51
Source File: HbaseHelper.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.hbase

import com.typesafe.scalalogging.StrictLogging
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.{Connection, Table}

object HbaseHelper extends StrictLogging {
  def autoclose[C <: AutoCloseable, T](closeable: C)(thunk: C => T): T = {
    try {
      thunk(closeable)
    }
    finally {
      if (closeable != null) {
        closeable.close()
      }
    }
  }

  def withTable[T](tableName: TableName)(thunk: Table => T)(implicit connection: Connection): T = {
    autoclose(connection.getTable(tableName))(thunk)
  }
} 
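
The autoclose helper works with any AutoCloseable; a minimal sketch (the file path is illustrative):

import java.io.{BufferedReader, FileReader}

val firstLine = HbaseHelper.autoclose(new BufferedReader(new FileReader("/tmp/example.txt"))) { reader =>
  reader.readLine() // the reader is closed even if this throws
}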
Example 52
Source File: ReThinkSourceReadersFactory.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.rethink.source

import java.util
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.atomic.AtomicBoolean

import com.datamountaineer.streamreactor.connect.rethink.ReThinkConnection
import com.datamountaineer.streamreactor.connect.rethink.config.{ReThinkSourceConfig, ReThinkSourceSetting, ReThinkSourceSettings}
import com.rethinkdb.RethinkDB
import com.rethinkdb.net.{Connection, Cursor}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.data.SchemaBuilder
import org.apache.kafka.connect.source.SourceRecord

import scala.collection.JavaConverters._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future

object ReThinkSourceReadersFactory {

  def apply(config: ReThinkSourceConfig, r: RethinkDB): Set[ReThinkSourceReader] = {
    val conn = Some(ReThinkConnection(r, config))
    val settings = ReThinkSourceSettings(config)
    settings.map(s => new ReThinkSourceReader(r, conn.get, s))
  }
}

class ReThinkSourceReader(rethink: RethinkDB, conn: Connection, setting: ReThinkSourceSetting)
  extends StrictLogging {

  logger.info(s"Initialising ReThink Reader for ${setting.source}")
  private val keySchema = SchemaBuilder.string().optional().build()
  private val valueSchema = ChangeFeedStructBuilder.schema
  private val sourcePartition = Map.empty[String, String]
  private val offset = Map.empty[String, String]
  private val stopFeed = new AtomicBoolean(false)
  private val handlingFeed = new AtomicBoolean(false)
  private var feed : Cursor[util.HashMap[String, String]] = _
  val queue = new LinkedBlockingQueue[SourceRecord]()
  val batchSize = setting.batchSize

  def start() = {
    feed = getChangeFeed()
    startFeed(feed)
  }

  def stop() = {
    logger.info(s"Closing change feed for ${setting.source}")
    stopFeed.set(true)
    while (handlingFeed.get()) {
      logger.debug("Waiting for feed to shutdown...")
      Thread.sleep(1000)
    }
    feed.close()
    logger.info(s"Change feed closed for ${setting.source}")
  }

  
  private def handleFeed(feed: Cursor[util.HashMap[String, String]]) = {
    handlingFeed.set(true)

    //feed.next is blocking
    while(!stopFeed.get()) {
      logger.debug(s"Waiting for next change feed event for ${setting.source}")
      val cdc = convert(feed.next().asScala.toMap)
      queue.put(cdc)
    }
    handlingFeed.set(false)
  }

  private def getChangeFeed(): Cursor[util.HashMap[String, String]] = {
    logger.info(s"Initialising change feed for ${setting.source}")
    rethink
      .db(setting.db)
      .table(setting.source)
      .changes()
      .optArg("include_states", true)
      .optArg("include_initial", setting.initialise)
      .optArg("include_types", true)
      .run(conn)
  }

  private def convert(feed: Map[String, String]) = {
    new SourceRecord(sourcePartition.asJava, offset.asJava, setting.target, keySchema, setting.source, valueSchema,
      ChangeFeedStructBuilder(feed))
  }
} 
Example 53
Source File: ReThinkSinkConnector.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.rethink.sink

import java.util

import com.datamountaineer.streamreactor.connect.config.Helpers
import com.datamountaineer.streamreactor.connect.rethink.ReThinkConnection
import com.datamountaineer.streamreactor.connect.rethink.config.{ReThinkConfigConstants, ReThinkSinkConfig, ReThinkSinkSettings}
import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.rethinkdb.RethinkDB
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigDef
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.sink.SinkConnector

import scala.collection.JavaConverters._



class ReThinkSinkConnector extends SinkConnector with StrictLogging {

  private var configProps: util.Map[String, String] = _
  private val configDef = ReThinkSinkConfig.config
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  override def taskClass(): Class[_ <: Task] = classOf[ReThinkSinkTask]

  override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] =
    List.fill(maxTasks)(configProps).asJava

  override def start(props: util.Map[String, String]): Unit = {
    val rethink = RethinkDB.r
    initializeTables(rethink, props)
    configProps = props
  }

  def initializeTables(rethink: RethinkDB, props: util.Map[String, String]): Unit = {
    val config = ReThinkSinkConfig(props)
    val settings = ReThinkSinkSettings(config)
    val rethinkHost = config.getString(ReThinkConfigConstants.RETHINK_HOST)
    val port = config.getInt(ReThinkConfigConstants.RETHINK_PORT)

    val conn = ReThinkConnection(rethink, config)
    ReThinkHelper.checkAndCreateTables(rethink, settings, conn)
    conn.close()
  }

  override def stop(): Unit = {}

  override def version(): String = manifest.version()

  override def config(): ConfigDef = configDef
} 
Example 54
Source File: ReThinkHelper.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.rethink.sink

import com.datamountaineer.streamreactor.connect.rethink.config.ReThinkSinkSetting
import com.rethinkdb.RethinkDB
import com.rethinkdb.net.Connection
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.errors.ConnectException

import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}


object ReThinkHelper extends StrictLogging {

  def checkAndCreateTables(rethink: RethinkDB, setting: ReThinkSinkSetting, conn: Connection): Unit = {
    val isAutoCreate = setting.kcql.map(r => (r.getTarget, r.isAutoCreate)).toMap
    val tables: java.util.List[String] = rethink.db(setting.database).tableList().run(conn)

    setting.topicTableMap
      .filter({ case (_, table) => !tables.contains(table) && isAutoCreate(table).equals(false) })
      .foreach({
        case (_, table) => throw new ConnectException(s"No table called $table found in database ${setting.database} and" +
          s" it's not set for AUTOCREATE")
      })

    //create any tables that are marked for auto create
    setting
      .kcql
      .filter(r => r.isAutoCreate)
      .filterNot(r => tables.contains(r.getTarget))
      .foreach(r => {
        logger.info(s"Creating table ${r.getTarget}")


        //set primary keys if we have them
        val pk = r.getPrimaryKeys.asScala.toSet
        val pkName = if (pk.isEmpty) "id" else pk.head.getName
        logger.info(s"Setting primary as first field found: $pkName")

        val create: java.util.Map[String, Object] = rethink
          .db(setting.database)
          .tableCreate(r.getTarget)
          .optArg("primary_key", pkName)
          .run(conn)

        Try(create) match {
          case Success(_) =>
            //logger.info(create.mkString(","))
            logger.info(s"Created table ${r.getTarget}.")
          case Failure(f) =>
            logger.error(s"Failed to create table ${r.getTarget}." +
              s" Error message  ${create.asScala.mkString(",")}, ${f.getMessage}")
        }
      })
  }

} 
Example 55
Source File: ReThinkConnection.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.rethink

import java.io.{BufferedInputStream, FileInputStream}

import com.datamountaineer.streamreactor.connect.rethink.config.ReThinkConfigConstants
import com.rethinkdb.RethinkDB
import com.rethinkdb.net.Connection
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.AbstractConfig
import org.apache.kafka.connect.errors.ConnectException


object ReThinkConnection extends StrictLogging {
  def apply(r: RethinkDB, config: AbstractConfig): Connection = {

    val host = config.getString(ReThinkConfigConstants.RETHINK_HOST)
    val port = config.getInt(ReThinkConfigConstants.RETHINK_PORT)
    val username = config.getString(ReThinkConfigConstants.USERNAME)
    val password = config.getPassword(ReThinkConfigConstants.PASSWORD).value()
    val certFile = config.getString(ReThinkConfigConstants.CERT_FILE)
    val authKey = config.getPassword(ReThinkConfigConstants.AUTH_KEY)

    //java driver also catches this
    if (username.nonEmpty && certFile.nonEmpty) {
      throw new ConnectException("Username and Certificate file can not be used together.")
    }

    if ((certFile.nonEmpty && config.getPassword(ReThinkConfigConstants.AUTH_KEY).value().isEmpty)
      || certFile.isEmpty && config.getPassword(ReThinkConfigConstants.AUTH_KEY).value().nonEmpty
    ) {
      throw new ConnectException("Both the certificate file and authentication key must be set for secure TLS connections.")
    }

    val builder = r.connection()
      .hostname(host)
      .port(port)

    if (!username.isEmpty) {
      logger.info("Adding username/password credentials to connection")
      builder.user(username, password)
    }

    if (!certFile.isEmpty) {
      logger.info(s"Using certificate file ${certFile} for TLS connection, overriding any SSLContext")
      val is = new BufferedInputStream(new FileInputStream(certFile))
      builder.certFile(is)
    }

    if (!authKey.value().isEmpty) {
      logger.info("Set authorization key")
      builder.authKey(authKey.value())
    }

    builder.connect()
  }
} 
Example 56
Source File: NestedGroupConverter.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.landoop.streamreactor.connect.hive.parquet

import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.data.{Field, Schema}
import org.apache.parquet.io.api.{Converter, GroupConverter}

import scala.collection.JavaConverters._

class NestedGroupConverter(schema: Schema,
                           field: Field,
                           parentBuilder: scala.collection.mutable.Map[String, Any])
  extends GroupConverter with StrictLogging {
  private[parquet] val builder = scala.collection.mutable.Map.empty[String, Any]
  private val converters = schema.fields.asScala.map(Converters.get(_, builder)).toIndexedSeq
  override def getConverter(k: Int): Converter = converters(k)
  override def start(): Unit = builder.clear()
  override def end(): Unit = parentBuilder.put(field.name, builder.result)
} 
Example 57
Source File: RootGroupConverter.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.landoop.streamreactor.connect.hive.parquet

import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.data.{Schema, Struct}
import org.apache.parquet.io.api.{Converter, GroupConverter}

import scala.collection.JavaConverters._

class RootGroupConverter(schema: Schema) extends GroupConverter with StrictLogging {
  require(schema.`type`() == Schema.Type.STRUCT)

  var struct: Struct = _
  private val builder = scala.collection.mutable.Map.empty[String, Any]
  private val converters = schema.fields.asScala.map(Converters.get(_, builder)).toIndexedSeq

  override def getConverter(k: Int): Converter = converters(k)
  override def start(): Unit = builder.clear()
  override def end(): Unit = struct = {
    val struct = new Struct(schema)
    schema.fields.asScala.foreach { field =>
      val value = builder.getOrElse(field.name, null)
      struct.put(field, value)
    }
    struct
  }
} 
Example 58
Source File: HiveSourceConnector.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.landoop.streamreactor.connect.hive.source

import java.util

import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.landoop.streamreactor.connect.hive.sink.config.HiveSinkConfigDef
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigDef
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.source.SourceConnector

import scala.collection.JavaConverters._

class HiveSourceConnector extends SourceConnector with StrictLogging {

  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)
  private var props: util.Map[String, String] = _

  override def version(): String = manifest.version()
  override def taskClass(): Class[_ <: Task] = classOf[HiveSourceTask]
  override def config(): ConfigDef = HiveSinkConfigDef.config

  override def start(props: util.Map[String, String]): Unit = {
    logger.info(s"Creating hive sink connector")
    this.props = props
  }

  override def stop(): Unit = ()

  override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] = {
    logger.info(s"Creating $maxTasks tasks config")
    List.fill(maxTasks)(props).asJava
  }
} 
Example 59
Source File: AsyncFunctionLoop.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.landoop.streamreactor.connect.hive

import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

import com.typesafe.scalalogging.StrictLogging

import scala.concurrent.duration.Duration

class AsyncFunctionLoop(interval: Duration, description: String)(thunk: => Unit)
  extends AutoCloseable
    with StrictLogging {

  private val running = new AtomicBoolean(false)
  private val executorService = Executors.newFixedThreadPool(1)

  def start(): Unit = {
    if (!running.compareAndSet(false, true)) {
      throw new IllegalStateException(s"$description already running.")
    }
    logger.info(s"Starting $description loop with an interval of ${interval.toMillis}ms.")
    executorService.submit(new Runnable {
      override def run(): Unit = {
        while (running.get()) {
          try {
            Thread.sleep(interval.toMillis)
            thunk
          }
          catch {
            case _: InterruptedException =>
            case t: Throwable =>
              logger.warn("Failed to renew the Kerberos ticket", t)
          }
        }
      }
    })
  }

  override def close(): Unit = {
    if (running.compareAndSet(true, false)) {
      executorService.shutdownNow()
      executorService.awaitTermination(10000, TimeUnit.MILLISECONDS)
    }
  }
} 
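
A minimal sketch of the loop above (the renewal body is a stand-in): the thunk runs on a single background thread at the given interval until close() is called.

import scala.concurrent.duration._

val ticketRenewer = new AsyncFunctionLoop(10.minutes, "Kerberos ticket renewal")(
  println("renewing ticket") // stand-in for the real renewal call
)

ticketRenewer.start()
// ... on shutdown ...
ticketRenewer.close()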
Example 60
Source File: InfluxDbWriter.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.influx.writers

import com.datamountaineer.streamreactor.connect.errors.ErrorHandler
import com.datamountaineer.streamreactor.connect.influx.config.InfluxSettings
import com.datamountaineer.streamreactor.connect.influx.{NanoClock, ValidateStringParameterFn}
import com.datamountaineer.streamreactor.connect.sink.DbWriter
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.sink.SinkRecord
import org.influxdb.InfluxDBFactory

import scala.util.Try

class InfluxDbWriter(settings: InfluxSettings) extends DbWriter with StrictLogging with ErrorHandler {

  ValidateStringParameterFn(settings.connectionUrl, "settings")
  ValidateStringParameterFn(settings.user, "settings")

  //initialize error tracker
  initialize(settings.maxRetries, settings.errorPolicy)
  private val influxDB = InfluxDBFactory.connect(settings.connectionUrl, settings.user, settings.password)
  private val builder = new InfluxBatchPointsBuilder(settings, new NanoClock())

  override def write(records: Seq[SinkRecord]): Unit = {
    if (records.isEmpty) {
      logger.debug("No records received.")
    } else {
      handleTry(
        builder
          .build(records)
          .flatMap { batchPoints =>
            logger.debug(s"Writing ${batchPoints.getPoints.size()} points to the database...")
            Try(influxDB.write(batchPoints))
          }.map(_ => logger.debug("Writing complete")))
    }
  }

  override def close(): Unit = {}
} 
Example 61
Source File: AsyncFunctionLoop.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.hbase.kerberos.utils

import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.{Executors, TimeUnit}

import com.typesafe.scalalogging.StrictLogging

import scala.concurrent.duration.Duration

class AsyncFunctionLoop(interval: Duration, description: String)(thunk: => Unit)
  extends AutoCloseable
    with StrictLogging {

  private val running = new AtomicBoolean(false)
  private val executorService = Executors.newSingleThreadExecutor

  def start(): Unit = {
    if (!running.compareAndSet(false, true)) {
      throw new IllegalStateException(s"$description already running.")
    }
    logger.info(s"Starting $description loop with an interval of ${interval.toMillis}ms.")
    executorService.submit(new Runnable {
      override def run(): Unit = {
        while (running.get()) {
          try {
            Thread.sleep(interval.toMillis)
            thunk
          }
          catch {
            case _: InterruptedException =>
            case t: Throwable =>
              logger.warn("Failed to renew the Kerberos ticket", t)
          }
        }
      }
    })
  }

  override def close(): Unit = {
    if (running.compareAndSet(true, false)) {
      executorService.shutdownNow()
      executorService.awaitTermination(10000, TimeUnit.MILLISECONDS)
    }
  }
} 
Example 62
Source File: StructFieldsExtractorBytes.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.hbase

import java.text.SimpleDateFormat
import java.util.TimeZone

import com.datamountaineer.streamreactor.connect.hbase.BytesHelper._
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.data._

import scala.collection.JavaConverters._

trait FieldsValuesExtractor {
  def get(struct: Struct): Seq[(String, Array[Byte])]
}

case class StructFieldsExtractorBytes(includeAllFields: Boolean, fieldsAliasMap: Map[String, String]) extends FieldsValuesExtractor with StrictLogging {

  def get(struct: Struct): Seq[(String, Array[Byte])] = {
    val schema = struct.schema()
    val fields: Seq[Field] = if (includeAllFields) {
      schema.fields().asScala
    }
    else {
      val selectedFields = schema.fields().asScala.filter(f => fieldsAliasMap.contains(f.name()))
      val diffSet = fieldsAliasMap.keySet.diff(selectedFields.map(_.name()).toSet)
      if (diffSet.nonEmpty) {
        val errMsg = s"Following columns ${diffSet.mkString(",")} have not been found. Available columns:${fieldsAliasMap.keys.mkString(",")}"
        logger.error(errMsg)
        sys.error(errMsg)
      }
      selectedFields
    }

    val fieldsAndValues = fields.flatMap(field =>
      getFieldBytes(field, struct).map(bytes => fieldsAliasMap.getOrElse(field.name(), field.name()) -> bytes))

    fieldsAndValues
  }

  private def getFieldBytes(field: Field, struct: Struct): Option[Array[Byte]] = {
    Option(struct.get(field))
      .map { value =>
        Option(field.schema().name()).collect {
          case Decimal.LOGICAL_NAME =>
            value.asInstanceOf[Any] match {
              case _: java.math.BigDecimal => value.fromBigDecimal()
              case arr: Array[Byte] => Decimal.toLogical(field.schema, arr).asInstanceOf[Any].fromBigDecimal()
              case _ => throw new IllegalArgumentException(s"${field.name()} is not handled for value:$value")
            }
          case Time.LOGICAL_NAME =>
            value.asInstanceOf[Any] match {
              case i: Int => StructFieldsExtractorBytes.TimeFormat.format(Time.toLogical(field.schema, i)).asInstanceOf[Any].fromString()
              case d: java.util.Date => StructFieldsExtractorBytes.TimeFormat.format(d).asInstanceOf[Any].fromString()
              case _ => throw new IllegalArgumentException(s"${field.name()} is not handled for value:$value")
            }

          case Timestamp.LOGICAL_NAME =>
            value.asInstanceOf[Any] match {
              case d: java.util.Date => StructFieldsExtractorBytes.DateFormat.format(d).asInstanceOf[Any].fromString()
              case l: Long => StructFieldsExtractorBytes.DateFormat.format(Timestamp.toLogical(field.schema, l)).asInstanceOf[Any].fromString()
              case _ => throw new IllegalArgumentException(s"${field.name()} is not handled for value:$value")
            }
        }.getOrElse {

          field.schema().`type`() match {
            case Schema.Type.BOOLEAN => value.fromBoolean()
            case Schema.Type.BYTES => value.fromBytes()
            case Schema.Type.FLOAT32 => value.fromFloat()
            case Schema.Type.FLOAT64 => value.fromDouble()
            case Schema.Type.INT8 => value.fromByte()
            case Schema.Type.INT16 => value.fromShort()
            case Schema.Type.INT32 => value.fromInt()
            case Schema.Type.INT64 => value.fromLong()
            case Schema.Type.STRING => value.fromString()
            case other => sys.error(s"$other is not a recognized schema!")
          }
        }
      }
  }
}


object StructFieldsExtractorBytes {
  val DateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")
  val TimeFormat = new SimpleDateFormat("HH:mm:ss.SSSZ")

  DateFormat.setTimeZone(TimeZone.getTimeZone("UTC"))
} 
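A short usage sketch for the extractor, assuming a simple Connect schema built inline. With includeAllFields = false only the fields named in the alias map are extracted, and they are emitted under their aliases.

import org.apache.kafka.connect.data.{Schema, SchemaBuilder, Struct}

val schema = SchemaBuilder.struct()
  .field("id", Schema.INT32_SCHEMA)
  .field("name", Schema.STRING_SCHEMA)
  .build()

val struct = new Struct(schema).put("id", 1).put("name", "alice")

// extract only 'name', exposing it under the alias 'name_bytes'
val extractor = StructFieldsExtractorBytes(includeAllFields = false, fieldsAliasMap = Map("name" -> "name_bytes"))
val fields: Seq[(String, Array[Byte])] = extractor.get(struct) // Seq(("name_bytes", <bytes of "alice">))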
Example 63
Source File: ChangeFeedStructBuilder.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.rethink.source

import com.fasterxml.jackson.databind.ObjectMapper
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.data.{Schema, SchemaBuilder, Struct}



object ChangeFeedStructBuilder extends StrictLogging {

  val mapper = new ObjectMapper()
  val oldVal = "old_val"
  val newVal = "new_val"
  val state = "state"
  val `type` = "type"

  val schema: Schema = SchemaBuilder.struct.name("ReThinkChangeFeed")
    .version(1)
    .field(state, Schema.OPTIONAL_STRING_SCHEMA)
    .field(oldVal, Schema.OPTIONAL_STRING_SCHEMA)
    .field(newVal, Schema.OPTIONAL_STRING_SCHEMA)
    .field(`type`, Schema.OPTIONAL_STRING_SCHEMA)
    .build

  def apply(hm: Map[String, Object]): Struct = {
    val struct = new Struct(schema)
    hm.foreach({ case (k, v) => if (v != null) struct.put(k, v.toString) })
    struct
  }
} 
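A usage sketch, assuming a RethinkDB change event already flattened into a Map; null values are skipped, so optional fields simply stay unset on the Struct.

val event: Map[String, Object] = Map(
  "state"   -> "ready",
  "old_val" -> null,
  "new_val" -> """{"id":1,"name":"alice"}""",
  "type"    -> "add"
)

val struct = ChangeFeedStructBuilder(event)
// struct.getString("new_val") returns the JSON string; "old_val" remains null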
Example 64
Source File: JMSReader.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.jms.source.readers

import com.datamountaineer.streamreactor.connect.converters.source.Converter
import com.datamountaineer.streamreactor.connect.jms.JMSSessionProvider
import com.datamountaineer.streamreactor.connect.jms.config.JMSSettings
import com.datamountaineer.streamreactor.connect.jms.source.domain.JMSStructMessage
import com.typesafe.scalalogging.StrictLogging
import javax.jms.{Message, MessageConsumer}
import org.apache.kafka.connect.source.SourceRecord

import scala.util.Try


class JMSReader(settings: JMSSettings) extends StrictLogging {

  val provider = JMSSessionProvider(settings)
  provider.start()
  val consumers: Vector[(String, MessageConsumer)] = (provider.queueConsumers ++ provider.topicsConsumers).toVector
  val convertersMap: Map[String, Option[Converter]] = settings.settings.map(s => (s.source, s.sourceConverters)).toMap
  val topicsMap: Map[String, String] = settings.settings.map(s => (s.source, s.target)).toMap

  def poll(): Vector[(Message, SourceRecord)] = {
    val messages = consumers
      .flatMap({ case (source, consumer) =>
        (0 until settings.batchSize) // poll at most batchSize messages per consumer
          .flatMap(_ => Option(consumer.receiveNoWait()))
          .map(m => (m, convert(source, topicsMap(source), m)))
      })

    messages
  }

  def convert(source: String, target: String, message: Message): SourceRecord = {
    convertersMap(source) match {
      case Some(c) => c.convert(target, source, message.getJMSMessageID, JMSStructMessage.getPayload(message))
      case None => JMSStructMessage.getStruct(target, message)
    }
  }

  def stop: Try[Unit] = provider.close()
}

object JMSReader {
  def apply(settings: JMSSettings): JMSReader = new JMSReader(settings)
} 
Example 65
Source File: JMSSourceConnector.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.jms.source

import java.util

import com.datamountaineer.streamreactor.connect.jms.config.{JMSConfig, JMSConfigConstants}
import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigDef
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.source.SourceConnector
import org.apache.kafka.connect.util.ConnectorUtils

import scala.collection.JavaConverters._


class JMSSourceConnector extends SourceConnector with StrictLogging {
  private var configProps: util.Map[String, String] = _
  private val configDef = JMSConfig.config
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  override def taskClass(): Class[_ <: Task] = classOf[JMSSourceTask]

  def kcqlTaskScaling(maxTasks: Int): util.List[util.Map[String, String]] = {
    val raw = configProps.get(JMSConfigConstants.KCQL)
    require(raw != null && !raw.isEmpty, s"No ${JMSConfigConstants.KCQL} provided!")

    //the raw value holds one or more KCQL statements separated by ';'
    val kcqls = raw.split(";")
    val groups = ConnectorUtils.groupPartitions(kcqls.toList.asJava, maxTasks).asScala

    //split up the kcql statement based on the number of tasks.
    groups
      .filterNot(_.isEmpty)
      .map { g =>
        val taskConfigs = new java.util.HashMap[String, String]
        taskConfigs.putAll(configProps)
        taskConfigs.put(JMSConfigConstants.KCQL, g.asScala.mkString(";")) //overwrite
        taskConfigs.asScala.toMap.asJava
      }
  }.asJava

  def defaultTaskScaling(maxTasks: Int): util.List[util.Map[String, String]] = {
    val raw = configProps.get(JMSConfigConstants.KCQL)
    require(raw != null && !raw.isEmpty, s"No ${JMSConfigConstants.KCQL} provided!")
    (1 to maxTasks).map { _ =>
      val taskConfigs: util.Map[String, String] = new java.util.HashMap[String, String]
      taskConfigs.putAll(configProps)
      taskConfigs
    }.toList.asJava
  }

  override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] = {
    val config = new JMSConfig(configProps)
    val scaleType = config.getString(JMSConfigConstants.TASK_PARALLELIZATION_TYPE).toLowerCase()
    if (scaleType == JMSConfigConstants.TASK_PARALLELIZATION_TYPE_DEFAULT) {
      kcqlTaskScaling(maxTasks)
    } else defaultTaskScaling(maxTasks)
  }

  override def config(): ConfigDef = configDef

  override def start(props: util.Map[String, String]): Unit = {
    val config = new JMSConfig(props)
    configProps = config.props
  }

  override def stop(): Unit = {}

  override def version(): String = manifest.version()
} 
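The KCQL-based scaling above leans on Kafka Connect's ConnectorUtils to spread statements over tasks. A small sketch of that grouping with hypothetical statements:

import org.apache.kafka.connect.util.ConnectorUtils

import scala.collection.JavaConverters._

val kcqls = List(
  "INSERT INTO topicA SELECT * FROM queueA",
  "INSERT INTO topicB SELECT * FROM queueB",
  "INSERT INTO topicC SELECT * FROM queueC"
)

// with maxTasks = 2 the statements land in two groups; each group becomes one task's KCQL
val groups = ConnectorUtils.groupPartitions(kcqls.asJava, 2).asScala
groups.foreach(group => println(group.asScala.mkString(";")))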
Example 66
Source File: JMSWriter.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.jms.sink.writer

import com.datamountaineer.streamreactor.connect.errors.ErrorHandler
import com.datamountaineer.streamreactor.connect.jms.JMSSessionProvider
import com.datamountaineer.streamreactor.connect.jms.config.{JMSSetting, JMSSettings}
import com.datamountaineer.streamreactor.connect.jms.sink.converters.{JMSHeadersConverterWrapper, JMSMessageConverter, JMSMessageConverterFn}
import com.datamountaineer.streamreactor.connect.schemas.ConverterUtil
import com.typesafe.scalalogging.StrictLogging
import javax.jms._
import org.apache.kafka.connect.sink.SinkRecord

import scala.util.{Failure, Success, Try}

case class JMSWriter(settings: JMSSettings) extends AutoCloseable with ConverterUtil with ErrorHandler with StrictLogging {

  val provider = JMSSessionProvider(settings, sink = true)
  provider.start()
  val producers: Map[String, MessageProducer] = provider.queueProducers ++ provider.topicProducers
  val converterMap: Map[String, JMSMessageConverter] = settings.settings
    .map(s => (s.source, JMSHeadersConverterWrapper(s.headers, JMSMessageConverterFn(s.format)))).toMap
  val settingsMap: Map[String, JMSSetting] = settings.settings.map(s => (s.source, s)).toMap

  //initialize error tracker
  initialize(settings.retries, settings.errorPolicy)

  
  def send(messages: Seq[(String, Message)]): Unit = {
    messages.foreach({ case (name, message) => producers(name).send(message)})
  }

  override def close(): Unit = provider.close()
} 
Example 67
Source File: CassandraWriter.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.cassandra.sink

import com.datamountaineer.streamreactor.connect.cassandra.CassandraConnection
import com.datamountaineer.streamreactor.connect.cassandra.config.{CassandraConfigConstants, CassandraConfigSink, CassandraSettings}
import com.datamountaineer.streamreactor.connect.errors.ErrorPolicyEnum
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.errors.ConnectException
import org.apache.kafka.connect.sink.SinkTaskContext

import scala.util.{Failure, Success, Try}

//Factory to build
object CassandraWriter extends StrictLogging {
  def apply(connectorConfig: CassandraConfigSink, context: SinkTaskContext): CassandraJsonWriter = {

    val connection = Try(CassandraConnection(connectorConfig)) match {
      case Success(s) => s
      case Failure(f) => throw new ConnectException(s"Couldn't connect to Cassandra.", f)
    }

    val settings = CassandraSettings.configureSink(connectorConfig)
    //if error policy is retry set retry interval
    if (settings.errorPolicy.equals(ErrorPolicyEnum.RETRY)) {
      context.timeout(connectorConfig.getString(CassandraConfigConstants.ERROR_RETRY_INTERVAL).toLong)
    }

    new CassandraJsonWriter(connection, settings)
  }
} 
Example 68
Source File: CassandraSinkConnector.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.cassandra.sink

import java.util

import com.datamountaineer.streamreactor.connect.cassandra.config.{CassandraConfigConstants, CassandraConfigSink}
import com.datamountaineer.streamreactor.connect.config.Helpers
import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigDef
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.errors.ConnectException
import org.apache.kafka.connect.sink.SinkConnector

import scala.collection.JavaConverters._
import scala.util.{Failure, Try}


  override def start(props: util.Map[String, String]): Unit = {
    //check input topics
    Helpers.checkInputTopics(CassandraConfigConstants.KCQL, props.asScala.toMap)
    configProps = props
    Try(new CassandraConfigSink(props)) match {
      case Failure(f) =>
        throw new ConnectException(s"Couldn't start Cassandra sink due to configuration error: ${f.getMessage}", f)
      case _ =>
    }
  }

  override def stop(): Unit = {}

  override def version(): String = manifest.version()

  override def config(): ConfigDef = configDef
} 
Example 69
Source File: CassandraSinkTask.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.cassandra.sink

import java.util

import com.datamountaineer.streamreactor.connect.cassandra.config.{CassandraConfigSink, CassandraSettings}
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.connect.errors.ConnectException
import org.apache.kafka.connect.sink.{SinkRecord, SinkTask}

import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}



  override def stop(): Unit = {
    logger.info("Stopping Cassandra sink.")
    writer.foreach(w => w.close())
    if (enableProgress) {
      progressCounter.empty
    }
  }

  override def flush(map: util.Map[TopicPartition, OffsetAndMetadata]): Unit = {}

  override def version: String = manifest.version()
} 
Example 70
Source File: MongoSinkConnector.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.mongodb.sink

import java.util

import com.datamountaineer.streamreactor.connect.config.Helpers
import com.datamountaineer.streamreactor.connect.mongodb.config.{MongoConfig, MongoConfigConstants}
import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.{Config, ConfigDef}
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.errors.ConnectException
import org.apache.kafka.connect.sink.SinkConnector

import scala.collection.JavaConverters._
import scala.util.{Failure, Try}


  override def start(props: util.Map[String, String]): Unit = {
    Helpers.checkInputTopics(MongoConfigConstants.KCQL_CONFIG, props.asScala.toMap)
    Try(MongoConfig(props)) match {
      case Failure(f) => throw new ConnectException(s"Couldn't start Mongo sink due to configuration error: ${f.getMessage}", f)
      case _ =>
    }

    configProps = props
  }

  override def stop(): Unit = {}

  override def version(): String = manifest.version()

  override def config(): ConfigDef = MongoConfig.config

  override def validate(connectorConfigs: util.Map[String, String]): Config = super.validate(connectorConfigs)
} 
Example 71
Source File: MongoSinkTask.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.mongodb.sink

import java.util

import com.datamountaineer.streamreactor.connect.mongodb.config.{MongoConfig, MongoConfigConstants}
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.connect.errors.ConnectException
import org.apache.kafka.connect.sink.{SinkRecord, SinkTask}

import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}


  override def put(records: util.Collection[SinkRecord]): Unit = {
    require(writer.nonEmpty, "Writer is not set!")
    val seq = records.asScala.toVector
    writer.foreach(w => w.write(seq))

    if (enableProgress) {
      progressCounter.update(seq)
    }
  }

  override def stop(): Unit = {
    logger.info("Stopping Mongo Database sink.")
    writer.foreach(w => w.close())
    progressCounter.empty
  }

  override def flush(map: util.Map[TopicPartition, OffsetAndMetadata]): Unit = {}

  override def version: String = manifest.version()
} 
Example 72
Source File: DocumentDbSinkSettings.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.azure.documentdb.config

import com.datamountaineer.kcql.Kcql
import com.datamountaineer.streamreactor.connect.errors.ErrorPolicy
import com.microsoft.azure.documentdb.ConsistencyLevel
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigException

case class DocumentDbSinkSettings(endpoint: String,
                                  masterKey: String,
                                  database: String,
                                  kcql: Seq[Kcql],
                                  keyBuilderMap: Map[String, Set[String]],
                                  fields: Map[String, Map[String, String]],
                                  ignoredField: Map[String, Set[String]],
                                  errorPolicy: ErrorPolicy,
                                  consistency: ConsistencyLevel,
                                  createDatabase: Boolean,
                                  proxy: Option[String],
                                  taskRetries: Int = DocumentDbConfigConstants.NBR_OF_RETIRES_DEFAULT
                                  ) {

}


object DocumentDbSinkSettings extends StrictLogging {

  def apply(config: DocumentDbConfig): DocumentDbSinkSettings = {
    val endpoint = config.getString(DocumentDbConfigConstants.CONNECTION_CONFIG)
    require(endpoint.nonEmpty, s"Invalid endpoint provided.${DocumentDbConfigConstants.CONNECTION_CONFIG_DOC}")

    val masterKey = Option(config.getPassword(DocumentDbConfigConstants.MASTER_KEY_CONFIG))
      .map(_.value())
      .getOrElse(throw new ConfigException(s"Missing ${DocumentDbConfigConstants.MASTER_KEY_CONFIG}"))
    require(masterKey.trim.nonEmpty, s"Invalid ${DocumentDbConfigConstants.MASTER_KEY_CONFIG}")

    val database = config.getDatabase

    if (database.isEmpty) {
      throw new ConfigException(s"Missing ${DocumentDbConfigConstants.DATABASE_CONFIG}.")
    }

    val kcql = config.getKCQL
    val errorPolicy= config.getErrorPolicy
    val retries = config.getNumberRetries
    val rowKeyBuilderMap = config.getUpsertKeys()
    val fieldsMap = config.getFieldsMap()
    val ignoreFields = config.getIgnoreFieldsMap()
    val consistencyLevel = config.getConsistencyLevel.get

    new DocumentDbSinkSettings(endpoint,
      masterKey,
      database,
      kcql.toSeq,
      rowKeyBuilderMap,
      fieldsMap,
      ignoreFields,
      errorPolicy,
      consistencyLevel,
      config.getBoolean(DocumentDbConfigConstants.CREATE_DATABASE_CONFIG),
      Option(config.getString(DocumentDbConfigConstants.PROXY_HOST_CONFIG)),
      retries)
  }
} 
Example 73
Source File: DocumentDbSinkTask.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.azure.documentdb.sink

import java.util

import com.datamountaineer.streamreactor.connect.azure.documentdb.DocumentClientProvider
import com.datamountaineer.streamreactor.connect.azure.documentdb.config.{DocumentDbConfig, DocumentDbConfigConstants, DocumentDbSinkSettings}
import com.datamountaineer.streamreactor.connect.errors.ErrorPolicyEnum
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.microsoft.azure.documentdb.DocumentClient
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.connect.errors.ConnectException
import org.apache.kafka.connect.sink.{SinkRecord, SinkTask}

import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}


  override def put(records: util.Collection[SinkRecord]): Unit = {
    require(writer.nonEmpty, "Writer is not set!")
    val seq = records.asScala.toVector
    writer.foreach(w => w.write(seq))

    if (enableProgress) {
      progressCounter.update(seq)
    }
  }

  override def stop(): Unit = {
    logger.info("Stopping Azure Document DB sink.")
    writer.foreach(w => w.close())
    progressCounter.empty()
  }

  override def flush(map: util.Map[TopicPartition, OffsetAndMetadata]): Unit = {}

  override def version: String = manifest.version()
} 
Example 74
Source File: PubSubSupport.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.redis.sink.writer

import com.datamountaineer.kcql.Kcql
import com.typesafe.scalalogging.StrictLogging

import scala.collection.JavaConverters._

trait PubSubSupport extends StrictLogging {

  // How to 'score' each message
  def getChannelField(kcqlConfig: Kcql): String = {
    val pubSubParams = kcqlConfig.getStoredAsParameters.asScala
    val channelField = if (pubSubParams.keys.exists(k => k.equalsIgnoreCase("channel")))
      pubSubParams.find { case (k, _) => k.equalsIgnoreCase("channel") }.get._2
    else {
      logger.info("You have not defined a 'channel' field. We'll try to fall back to 'channel' field")
      "channel"
    }
    channelField
  }

//   assert(SS.isValid, "The SortedSet definition at Redis accepts only case sensitive alphabetic characters")

} 
Example 75
Source File: RootGroupConverter.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.landoop.streamreactor.connect.hive.parquet

import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.data.{Schema, Struct}
import org.apache.parquet.io.api.{Converter, GroupConverter}

import scala.collection.JavaConverters._

class RootGroupConverter(schema: Schema) extends GroupConverter with StrictLogging {
  require(schema.`type`() == Schema.Type.STRUCT)

  var struct: Struct = _
  private val builder = scala.collection.mutable.Map.empty[String, Any]
  private val converters = schema.fields.asScala.map(Converters.get(_, builder)).toIndexedSeq

  override def getConverter(k: Int): Converter = converters(k)
  override def start(): Unit = builder.clear()
  override def end(): Unit = struct = {
    val struct = new Struct(schema)
    schema.fields.asScala.foreach { field =>
      val value = builder.getOrElse(field.name, null)
      struct.put(field, value)
    }
    struct
  }
} 
Example 76
Source File: DocumentDbWriter.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.azure.documentdb.sink

import com.datamountaineer.kcql.WriteModeEnum
import com.datamountaineer.streamreactor.connect.azure.documentdb.DocumentClientProvider
import com.datamountaineer.streamreactor.connect.azure.documentdb.config.{DocumentDbConfig, DocumentDbConfigConstants, DocumentDbSinkSettings}
import com.datamountaineer.streamreactor.connect.errors.{ErrorHandler, ErrorPolicyEnum}
import com.datamountaineer.streamreactor.connect.schemas.ConverterUtil
import com.microsoft.azure.documentdb._
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.sink.{SinkRecord, SinkTaskContext}

import scala.util.Failure


  private def insert(records: Seq[SinkRecord]) = {
    try {
      records.groupBy(_.topic()).foreach { case (_, groupedRecords) =>
        groupedRecords.foreach { record =>
          val (document, keysAndValues) = SinkRecordToDocument(record, settings.keyBuilderMap.getOrElse(record.topic(), Set.empty))(settings)

          val key = keysAndValues.flatMap { case (_, v) => Option(v) }.mkString(".")
          if (key.nonEmpty) {
            document.setId(key)
          }
          val config = configMap.getOrElse(record.topic(), sys.error(s"${record.topic()} is not handled by the configuration."))
          config.getWriteMode match {
            case WriteModeEnum.INSERT =>
              documentClient.createDocument(s"dbs/${settings.database}/colls/${config.getTarget}", document, requestOptionsInsert, key.nonEmpty).getResource

            case WriteModeEnum.UPSERT =>
              documentClient.upsertDocument(s"dbs/${settings.database}/colls/${config.getTarget}", document, requestOptionsInsert, key.nonEmpty).getResource
          }
        }
      }
    }
    catch {
      case t: Throwable =>
        logger.error(s"There was an error inserting the records ${t.getMessage}", t)
        handleTry(Failure(t))
    }
  }

  def close(): Unit = {
    logger.info("Shutting down Document DB writer.")
    documentClient.close()
  }
}


//Factory to build
object DocumentDbWriter extends StrictLogging {
  def apply(connectorConfig: DocumentDbConfig, context: SinkTaskContext): DocumentDbWriter = {

    implicit val settings = DocumentDbSinkSettings(connectorConfig)
    //if error policy is retry set retry interval
    if (settings.errorPolicy.equals(ErrorPolicyEnum.RETRY)) {
      context.timeout(connectorConfig.getLong(DocumentDbConfigConstants.ERROR_RETRY_INTERVAL_CONFIG))
    }

    logger.info(s"Initialising Document Db writer.")
    val provider = DocumentClientProvider.get(settings)
    new DocumentDbWriter(settings, provider)
  }
} 
Example 77
Source File: AsyncFunctionLoop.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.landoop.streamreactor.connect.hive

import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

import com.typesafe.scalalogging.StrictLogging

import scala.concurrent.duration.Duration

class AsyncFunctionLoop(interval: Duration, description: String)(thunk: => Unit)
  extends AutoCloseable
    with StrictLogging {

  private val running = new AtomicBoolean(false)
  private val executorService = Executors.newFixedThreadPool(1)

  def start(): Unit = {
    if (!running.compareAndSet(false, true)) {
      throw new IllegalStateException(s"$description already running.")
    }
    logger.info(s"Starting $description loop with an interval of ${interval.toMillis}ms.")
    executorService.submit(new Runnable {
      override def run(): Unit = {
        while (running.get()) {
          try {
            Thread.sleep(interval.toMillis)
            thunk
          }
          catch {
            case _: InterruptedException =>
            case t: Throwable =>
              logger.warn("Failed to renew the Kerberos ticket", t)
          }
        }
      }
    })
  }

  override def close(): Unit = {
    if (running.compareAndSet(true, false)) {
      executorService.shutdownNow()
      executorService.awaitTermination(10000, TimeUnit.MILLISECONDS)
    }
  }
} 
Example 78
Source File: OrcSink.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.landoop.streamreactor.connect.hive.orc

import com.landoop.streamreactor.connect.hive.orc.vectors.{OrcVectorWriter, StructVectorWriter}
import com.landoop.streamreactor.connect.hive.{OrcSinkConfig, StructUtils}
import com.typesafe.scalalogging.StrictLogging
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector
import org.apache.kafka.connect.data.{Schema, Struct}

import scala.collection.JavaConverters._

class OrcSink(path: Path,
              schema: Schema,
              config: OrcSinkConfig)(implicit fs: FileSystem) extends StrictLogging {

  private val typeDescription = OrcSchemas.toOrc(schema)
  private val structWriter = new StructVectorWriter(typeDescription.getChildren.asScala.map(OrcVectorWriter.fromSchema))
  private val batch = typeDescription.createRowBatch(config.batchSize)
  private val vector = new StructColumnVector(batch.numCols, batch.cols: _*)
  private val orcWriter = createOrcWriter(path, typeDescription, config)
  private var n = 0

  def flush(): Unit = {
    logger.debug(s"Writing orc batch [size=$n, path=$path]")
    batch.size = n
    orcWriter.addRowBatch(batch)
    orcWriter.writeIntermediateFooter
    batch.reset()
    n = 0
  }

  def write(struct: Struct): Unit = {
    structWriter.write(vector, n, Some(StructUtils.extractValues(struct)))
    n = n + 1
    if (n == config.batchSize)
      flush()
  }

  def close(): Unit = {
    if (n > 0)
      flush()
    orcWriter.close()
  }
} 
Example 79
Source File: OrcSource.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.landoop.streamreactor.connect.hive.orc

import com.landoop.streamreactor.connect.hive.OrcSourceConfig
import com.landoop.streamreactor.connect.hive.orc.vectors.OrcVectorReader.fromSchema
import com.landoop.streamreactor.connect.hive.orc.vectors.StructVectorReader
import com.typesafe.scalalogging.StrictLogging
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.ql.exec.vector.{StructColumnVector, VectorizedRowBatch}
import org.apache.kafka.connect.data.Struct
import org.apache.orc.OrcFile.ReaderOptions
import org.apache.orc.{OrcFile, Reader}

import scala.collection.JavaConverters._

class OrcSource(path: Path, config: OrcSourceConfig)(implicit fs: FileSystem) extends StrictLogging {

  private val reader = OrcFile.createReader(path, new ReaderOptions(fs.getConf))

  private val typeDescription = reader.getSchema
  private val schema = OrcSchemas.toKafka(typeDescription)

  private val readers = typeDescription.getChildren.asScala.map(fromSchema)
  private val vectorReader = new StructVectorReader(readers.toIndexedSeq, typeDescription)

  private val batch = typeDescription.createRowBatch()
  private val recordReader = reader.rows(new Reader.Options())

  def close(): Unit = {
    recordReader.close()
  }

  def iterator: Iterator[Struct] = new Iterator[Struct] {
    var iter = new BatchIterator(batch)
    override def hasNext: Boolean = iter.hasNext || {
      batch.reset()
      recordReader.nextBatch(batch)
      iter = new BatchIterator(batch)
      !batch.endOfFile && batch.size > 0 && iter.hasNext
    }
    override def next(): Struct = iter.next()
  }

  // iterates over a batch, be careful not to mutate the batch while it is being iterated
  class BatchIterator(batch: VectorizedRowBatch) extends Iterator[Struct] {
    var offset = 0
    val vector = new StructColumnVector(batch.numCols, batch.cols: _*)
    override def hasNext: Boolean = offset < batch.size
    override def next(): Struct = {
      val struct = vectorReader.read(offset, vector)
      offset = offset + 1
      struct.orNull
    }
  }
} 
Example 80
Source File: MqttSSLSocketFactory.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.mqtt.source

import java.io.FileReader
import java.security.{KeyStore, Security}

import com.typesafe.scalalogging.StrictLogging
import javax.net.ssl.{KeyManagerFactory, SSLContext, SSLSocketFactory, TrustManagerFactory}
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.openssl.jcajce.{JcaPEMKeyConverter, JcePEMDecryptorProviderBuilder}
import org.bouncycastle.openssl.{PEMEncryptedKeyPair, PEMKeyPair, PEMParser}


object MqttSSLSocketFactory extends StrictLogging {
  def apply(caCrtFile: String,
            crtFile: String,
            keyFile: String,
            password: String): SSLSocketFactory = {
    try {

      
      context.getSocketFactory
    }
    catch {
      case e: Exception =>
        logger.warn(e.getMessage, e)
        null
    }
  }
} 
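The listing above has the body of apply() elided. The imports suggest the usual BouncyCastle recipe: parse the PEM files, load a trust store with the CA certificate and a key store with the client certificate and key, then wire both into an SSLContext. A hedged reconstruction of that pattern follows; the method name, the "TLSv1.2" protocol and the key-store aliases are assumptions, not necessarily what the project uses.

import java.io.FileReader
import java.security.{KeyStore, Security}

import javax.net.ssl.{KeyManagerFactory, SSLContext, SSLSocketFactory, TrustManagerFactory}
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.openssl.jcajce.{JcaPEMKeyConverter, JcePEMDecryptorProviderBuilder}
import org.bouncycastle.openssl.{PEMEncryptedKeyPair, PEMKeyPair, PEMParser}

def sslSocketFactorySketch(caCrtFile: String, crtFile: String, keyFile: String, password: String): SSLSocketFactory = {
  Security.addProvider(new BouncyCastleProvider())
  val certConverter = new JcaX509CertificateConverter().setProvider("BC")

  // CA certificate -> trust store
  val caParser = new PEMParser(new FileReader(caCrtFile))
  val caCert = certConverter.getCertificate(caParser.readObject().asInstanceOf[X509CertificateHolder])
  caParser.close()
  val trustStore = KeyStore.getInstance(KeyStore.getDefaultType)
  trustStore.load(null, null)
  trustStore.setCertificateEntry("ca-certificate", caCert)
  val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
  tmf.init(trustStore)

  // client certificate
  val crtParser = new PEMParser(new FileReader(crtFile))
  val clientCert = certConverter.getCertificate(crtParser.readObject().asInstanceOf[X509CertificateHolder])
  crtParser.close()

  // client private key, optionally encrypted with the supplied password
  val keyParser = new PEMParser(new FileReader(keyFile))
  val keyConverter = new JcaPEMKeyConverter().setProvider("BC")
  val keyPair = keyParser.readObject() match {
    case enc: PEMEncryptedKeyPair =>
      keyConverter.getKeyPair(enc.decryptKeyPair(new JcePEMDecryptorProviderBuilder().build(password.toCharArray)))
    case plain: PEMKeyPair => keyConverter.getKeyPair(plain)
    case other => sys.error(s"Unsupported PEM content in $keyFile: $other")
  }
  keyParser.close()

  // client certificate + key -> key store
  val keyStore = KeyStore.getInstance(KeyStore.getDefaultType)
  keyStore.load(null, null)
  keyStore.setCertificateEntry("certificate", clientCert)
  keyStore.setKeyEntry("private-key", keyPair.getPrivate, password.toCharArray, Array[java.security.cert.Certificate](clientCert))
  val kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
  kmf.init(keyStore, password.toCharArray)

  // assumed protocol version
  val context = SSLContext.getInstance("TLSv1.2")
  context.init(kmf.getKeyManagers, tmf.getTrustManagers, null)
  context.getSocketFactory
}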
Example 81
Source File: MqttSourceTask.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.mqtt.source

import java.io.File
import java.util

import com.datamountaineer.streamreactor.connect.converters.source.Converter
import com.datamountaineer.streamreactor.connect.mqtt.config.{MqttConfigConstants, MqttSourceConfig, MqttSourceSettings}
import com.datamountaineer.streamreactor.connect.mqtt.connection.MqttClientConnectionFn
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigException
import org.apache.kafka.connect.source.{SourceRecord, SourceTask}

import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}

class MqttSourceTask extends SourceTask with StrictLogging {
  private val progressCounter = new ProgressCounter
  private var enableProgress: Boolean = false
  private var mqttManager: Option[MqttManager] = None
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  override def start(props: util.Map[String, String]): Unit = {

    logger.info(scala.io.Source.fromInputStream(this.getClass.getResourceAsStream("/mqtt-source-ascii.txt")).mkString + s" $version")
    logger.info(manifest.printManifest())

    val conf = if (context.configs().isEmpty) props else context.configs()

    val settings = MqttSourceSettings(MqttSourceConfig(conf))

    settings.sslCACertFile.foreach { file =>
      if (!new File(file).exists()) {
        throw new ConfigException(s"${MqttConfigConstants.SSL_CA_CERT_CONFIG} is invalid. Can't locate $file")
      }
    }

    settings.sslCertFile.foreach { file =>
      if (!new File(file).exists()) {
        throw new ConfigException(s"${MqttConfigConstants.SSL_CERT_CONFIG} is invalid. Can't locate $file")
      }
    }

    settings.sslCertKeyFile.foreach { file =>
      if (!new File(file).exists()) {
        throw new ConfigException(s"${MqttConfigConstants.SSL_CERT_KEY_CONFIG} is invalid. Can't locate $file")
      }
    }

    val convertersMap = settings.sourcesToConverters.map { case (topic, clazz) =>
      logger.info(s"Creating converter instance for $clazz")
      val converter = Try(Class.forName(clazz).newInstance()) match {
        case Success(value) => value.asInstanceOf[Converter]
        case Failure(_) => throw new ConfigException(s"Invalid ${MqttConfigConstants.KCQL_CONFIG} is invalid. $clazz should have an empty ctor!")
      }
      import scala.collection.JavaConverters._
      converter.initialize(conf.asScala.toMap)
      topic -> converter
    }

    logger.info("Starting Mqtt source...")
    mqttManager = Some(new MqttManager(MqttClientConnectionFn.apply, convertersMap, settings))
    enableProgress = settings.enableProgress
  }

  
  override def stop(): Unit = {
    logger.info("Stopping Mqtt source.")
    mqttManager.foreach(_.close())
    progressCounter.empty
  }

  override def version: String = manifest.version()
} 
Example 82
Source File: MqttClientConnectionFn.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.mqtt.connection

import com.datamountaineer.streamreactor.connect.mqtt.config.{MqttSinkSettings, MqttSourceSettings}
import com.datamountaineer.streamreactor.connect.mqtt.source.MqttSSLSocketFactory
import com.typesafe.scalalogging.StrictLogging
import org.eclipse.paho.client.mqttv3.{MqttClient, MqttConnectOptions}
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence

object MqttClientConnectionFn extends StrictLogging {
  def apply(settings: MqttSourceSettings): MqttConnectOptions = {
    buildBaseClient(
      settings.connectionTimeout,
      settings.keepAliveInterval,
      settings.cleanSession,
      settings.user,
      settings.password,
      settings.sslCertFile,
      settings.sslCACertFile,
      settings.sslCertKeyFile,
      settings.connection
    )
  }

  def apply(settings: MqttSinkSettings): MqttClient = {
    val options = buildBaseClient(
      settings.connectionTimeout,
      settings.keepAliveInterval,
      settings.cleanSession,
      settings.user,
      settings.password,
      settings.sslCertFile,
      settings.sslCACertFile,
      settings.sslCertKeyFile,
      settings.connection
    )

    val servers = settings.connection.split(',').map(_.trim).filter(_.nonEmpty)
    val c = new MqttClient(servers.head, settings.clientId, new MemoryPersistence())
    logger.info(s"Connecting to ${settings.connection}")
    c.connect(options)
    logger.info(s"Connected to ${settings.connection} as ${settings.clientId}")
    c
  }

  def buildBaseClient(connectionTimeout: Int,
                      keepAliveInterval: Int,
                      cleanSession: Boolean,
                      username: Option[String],
                      password: Option[String],
                      sslCertFile: Option[String],
                      sslCACertFile: Option[String],
                      sslCertKeyFile: Option[String],
                      connection:String): MqttConnectOptions = {
    val options = new MqttConnectOptions()
    options.setConnectionTimeout(connectionTimeout)
    options.setKeepAliveInterval(keepAliveInterval)
    options.setCleanSession(cleanSession)
    username.foreach(n => options.setUserName(n))
    password.foreach(p => options.setPassword(p.toCharArray))
    options.setAutomaticReconnect(true)

    val servers = connection.split(',').map(_.trim).filter(_.nonEmpty)
    options.setServerURIs(servers)

    sslCertFile.foreach { _ =>
      options.setSocketFactory(
        MqttSSLSocketFactory(sslCACertFile.get, sslCertFile.get, sslCertKeyFile.get, "")
      )
    }

    options
  }
} 
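A usage sketch of buildBaseClient with hypothetical broker addresses and credentials; the TLS file options are left as None, so no custom socket factory is installed.

val options = MqttClientConnectionFn.buildBaseClient(
  connectionTimeout = 30,
  keepAliveInterval = 60,
  cleanSession = true,
  username = Some("connect-user"),
  password = Some("connect-password"),
  sslCertFile = None,
  sslCACertFile = None,
  sslCertKeyFile = None,
  connection = "tcp://mqtt-broker-1:1883,tcp://mqtt-broker-2:1883"
)
// options.getServerURIs now lists both brokers and automatic reconnect is enabled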
Example 83
Source File: MqttSinkTask.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.mqtt.sink

import java.util

import com.datamountaineer.streamreactor.connect.converters.sink.Converter
import com.datamountaineer.streamreactor.connect.errors.ErrorPolicyEnum
import com.datamountaineer.streamreactor.connect.mqtt.config.{MqttConfigConstants, MqttSinkConfig, MqttSinkSettings}
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.config.ConfigException
import org.apache.kafka.connect.sink.{SinkRecord, SinkTask}

import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}



  override def stop(): Unit = {
    logger.info("Stopping Mqtt sink.")
    writer.foreach(w => w.close)
    progressCounter.empty
  }

  override def flush(map: util.Map[TopicPartition, OffsetAndMetadata]): Unit = {
    require(writer.nonEmpty, "Writer is not set!")
    writer.foreach(w => w.flush)
  }

  override def version: String = manifest.version()
} 
Example 84
Source File: MqttSinkConnector.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.mqtt.sink

import java.util

import com.datamountaineer.streamreactor.connect.config.Helpers
import com.datamountaineer.streamreactor.connect.mqtt.config.{MqttConfigConstants, MqttSinkConfig}
import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigDef
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.sink.SinkConnector

import scala.collection.JavaConverters._


class MqttSinkConnector extends SinkConnector with StrictLogging {
  private val configDef = MqttSinkConfig.config
  private var configProps: Option[util.Map[String, String]] = None
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  override def start(props: util.Map[String, String]): Unit = {
    logger.info(s"Starting Mqtt sink connector.")
    Helpers.checkInputTopics(MqttConfigConstants.KCQL_CONFIG, props.asScala.toMap)
    configProps = Some(props)
  }

  override def taskClass(): Class[_ <: Task] = classOf[MqttSinkTask]

  override def version(): String = manifest.version()

  override def stop(): Unit = {}

  override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] = {
    logger.info(s"Setting task configurations for $maxTasks workers.")
    (1 to maxTasks).map(_ => configProps.get).toList.asJava
  }

  override def config(): ConfigDef = configDef
} 
Example 85
Source File: NestedGroupConverter.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.landoop.streamreactor.connect.hive.parquet

import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.data.{Field, Schema}
import org.apache.parquet.io.api.{Converter, GroupConverter}

import scala.collection.JavaConverters._

class NestedGroupConverter(schema: Schema,
                           field: Field,
                           parentBuilder: scala.collection.mutable.Map[String, Any])
  extends GroupConverter with StrictLogging {
  private[parquet] val builder = scala.collection.mutable.Map.empty[String, Any]
  private val converters = schema.fields.asScala.map(Converters.get(_, builder)).toIndexedSeq
  override def getConverter(k: Int): Converter = converters(k)
  override def start(): Unit = builder.clear()
  override def end(): Unit = parentBuilder.put(field.name, builder.result)
} 
Example 86
Source File: SortedSetSupport.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.redis.sink.writer

import com.datamountaineer.kcql.Kcql
import com.typesafe.scalalogging.StrictLogging

import scala.collection.JavaConverters._

trait SortedSetSupport extends StrictLogging {

  // How to 'score' each message
  def getScoreField(kcqlConfig: Kcql): String = {
    val sortedSetParams = kcqlConfig.getStoredAsParameters.asScala
    val scoreField = if (sortedSetParams.keys.exists(k => k.equalsIgnoreCase("score")))
      sortedSetParams.find { case (k, _) => k.equalsIgnoreCase("score") }.get._2
    else {
      logger.info("You have not defined how to 'score' each message. We'll try to fall back to 'timestamp' field")
      "timestamp"
    }
    scoreField
  }

  // assert(SS.isValid, "The SortedSet definition at Redis accepts only case sensitive alphabetic characters")

} 
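A sketch of how the score field is resolved, assuming the Redis sink's STOREAS SortedSet syntax; the target, topic and field names are made up for illustration.

import com.datamountaineer.kcql.Kcql

object SortedSetExample extends SortedSetSupport

// the 'score' parameter names the field used to score each member of the sorted set
val kcql = Kcql.parse("INSERT INTO fx_prices SELECT price FROM fxTopic STOREAS SortedSet(score=ts)")
val scoreField = SortedSetExample.getScoreField(kcql) // "ts"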
Example 87
Source File: RedisWriter.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.redis.sink.writer

import java.io.{File, FileNotFoundException}

import com.datamountaineer.streamreactor.connect.errors.ErrorHandler
import com.datamountaineer.streamreactor.connect.redis.sink.config.RedisSinkSettings
import com.datamountaineer.streamreactor.connect.schemas.ConverterUtil
import com.datamountaineer.streamreactor.connect.sink._
import com.typesafe.scalalogging.StrictLogging
import redis.clients.jedis.Jedis



abstract class RedisWriter extends DbWriter with StrictLogging with ConverterUtil with ErrorHandler {

  var jedis: Jedis = _

  def createClient(sinkSettings: RedisSinkSettings): Unit = {
    val connection = sinkSettings.connectionInfo

    if (connection.isSslConnection) {
        connection.keyStoreFilepath match {
          case Some(path) =>
            if (!new File(path).exists) {
              throw new FileNotFoundException(s"Keystore not found in: $path")
            }

            System.setProperty("javax.net.ssl.keyStorePassword", connection.keyStorePassword.getOrElse(""))
            System.setProperty("javax.net.ssl.keyStore", path)
            System.setProperty("javax.net.ssl.keyStoreType", connection.keyStoreType.getOrElse("jceks"))

          case None =>
        }

        connection.trustStoreFilepath match {
          case Some(path) =>
            if (!new File(path).exists) {
              throw new FileNotFoundException(s"Truststore not found in: $path")
            }

            System.setProperty("javax.net.ssl.trustStorePassword", connection.trustStorePassword.getOrElse(""))
            System.setProperty("javax.net.ssl.trustStore", path)
            System.setProperty("javax.net.ssl.trustStoreType", connection.trustStoreType.getOrElse("jceks"))

          case None =>
        }
    }

    jedis = new Jedis(connection.host, connection.port, connection.isSslConnection)
    connection.password.foreach(p => jedis.auth(p))

    //initialize error tracker
    initialize(sinkSettings.taskRetries, sinkSettings.errorPolicy)
  }

  def close(): Unit = {
    if (jedis != null) {
      jedis.close()
    }
  }

} 
Example 88
Source File: GeoAddSupport.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.redis.sink.writer

import com.datamountaineer.kcql.Kcql
import com.typesafe.scalalogging.StrictLogging

import scala.collection.JavaConverters._

trait GeoAddSupport extends StrictLogging {

  def getLongitudeField(kcqlConfig: Kcql): String = {
    getStoredAsParameter("longitudeField", kcqlConfig, "longitude")
  }

  def getLatitudeField(kcqlConfig: Kcql): String = {
    getStoredAsParameter("latitudeField", kcqlConfig, "latitude")
  }

  def getStoredAsParameter(parameterName: String, kcqlConfig: Kcql, defaultValue: String): String = {
    val geoAddParams = kcqlConfig.getStoredAsParameters.asScala
    val parameterValue = if (geoAddParams.keys.exists(k => k.equalsIgnoreCase(parameterName)))
      geoAddParams.find { case (k, _) => k.equalsIgnoreCase(parameterName) }.get._2
    else {
      logger.info(s"You have not defined a $parameterName field. We'll try to fall back to '$defaultValue' field")
      defaultValue
    }
    parameterValue
  }
} 
Example 89
Source File: ProducerConfigFactory.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.pulsar

import java.util.concurrent.TimeUnit

import com.datamountaineer.kcql.Kcql
import com.typesafe.scalalogging.StrictLogging
import org.apache.pulsar.client.api.ProducerConfiguration.MessageRoutingMode
import org.apache.pulsar.client.api.{CompressionType, ProducerConfiguration}


object ProducerConfigFactory extends StrictLogging {
  def apply(name: String, kcqls : Set[Kcql]): Map[String, ProducerConfiguration] = {


    kcqls.map(kcql => {
      val conf = new ProducerConfiguration()

      // set batching
      if (kcql.getBatchSize > 0) {
        conf.setBatchingEnabled(true)
        conf.setBatchingMaxMessages(kcql.getBatchSize)

        if (kcql.getWithDelay > 0) {
          conf.setBatchingMaxPublishDelay(kcql.getWithDelay, TimeUnit.MILLISECONDS)
        }
      }

      // set compression type
      if (kcql.getWithCompression != null) {

        val compressionType = kcql.getWithCompression match {
          case com.datamountaineer.kcql.CompressionType.LZ4 => CompressionType.LZ4
          case com.datamountaineer.kcql.CompressionType.ZLIB => CompressionType.ZLIB
          case _ =>
            logger.warn(s"Unknown supported compression type ${kcql.getWithCompression.toString}. Defaulting to LZ4")
            CompressionType.LZ4
        }

        conf.setCompressionType(compressionType)
      }

      // set routing mode
      conf.setMessageRoutingMode(getMessageRouting(kcql))
      conf.setProducerName(name)

      (kcql.getTarget, conf)
    }).toMap
  }

  def getMessageRouting(kcql: Kcql): MessageRoutingMode = {
    // set routing mode
    // match on strings rather than enums, since the Pulsar values are camel-cased
    if (kcql.getWithPartitioner != null) {
      kcql.getWithPartitioner.trim.toUpperCase match {
        case "SINGLEPARTITION" =>
          MessageRoutingMode.SinglePartition

        case "ROUNDROBINPARTITION" =>
          MessageRoutingMode.RoundRobinPartition

        case "CUSTOMPARTITION" =>
          MessageRoutingMode.CustomPartition

        case _ =>
          logger.error(s"Unknown message routing mode ${kcql.getWithType}. Defaulting to SinglePartition")
          MessageRoutingMode.SinglePartition
      }

    } else {
      logger.info(s"Defaulting to SinglePartition message routing mode")
      MessageRoutingMode.SinglePartition
    }
  }
} 
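A usage sketch with an assumed KCQL statement; the Pulsar topic, batch size, delay and compression values are illustrative. The WITH* clauses map directly onto the ProducerConfiguration built above.

import com.datamountaineer.kcql.Kcql

val kcql = Kcql.parse(
  "INSERT INTO persistent://landoop/standalone/connect/kafka-topic SELECT * FROM kafka_topic " +
    "BATCH = 10 WITHDELAY = 1000 WITHCOMPRESSION = LZ4 WITHPARTITIONER = SinglePartition")

val configs = ProducerConfigFactory("connect-pulsar-producer", Set(kcql))
val conf = configs("persistent://landoop/standalone/connect/kafka-topic")
// batching is enabled (10 messages / 1000 ms), compression is LZ4 and routing is SinglePartition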
Example 90
Source File: PulsarSourceTask.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.pulsar.source

import java.util
import java.util.UUID

import com.datamountaineer.streamreactor.connect.converters.source.Converter
import com.datamountaineer.streamreactor.connect.pulsar.config.{PulsarConfigConstants, PulsarSourceConfig, PulsarSourceSettings}
import com.datamountaineer.streamreactor.connect.utils.{JarManifest, ProgressCounter}
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.source.{SourceRecord, SourceTask}
import org.apache.pulsar.client.api.{ClientConfiguration, PulsarClient}
import org.apache.pulsar.client.impl.auth.AuthenticationTls
import org.apache.zookeeper.server.quorum.QuorumPeerConfig.ConfigException

import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}

class PulsarSourceTask extends SourceTask with StrictLogging {
  private val progressCounter = new ProgressCounter
  private var enableProgress: Boolean = false
  private var pulsarManager: Option[PulsarManager] = None
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  override def start(props: util.Map[String, String]): Unit = {

    logger.info(scala.io.Source.fromInputStream(this.getClass.getResourceAsStream("/pulsar-source-ascii.txt")).mkString + s" $version")
    logger.info(manifest.printManifest())

    val conf = if (context.configs().isEmpty) props else context.configs()

    implicit val settings = PulsarSourceSettings(PulsarSourceConfig(conf), props.getOrDefault("tasks.max", "1").toInt)


    val name = conf.getOrDefault("name", s"kafka-connect-pulsar-source-${UUID.randomUUID().toString}")
    val convertersMap = buildConvertersMap(conf, settings)

    val messageConverter = PulsarMessageConverter(
      convertersMap,
      settings.kcql,
      settings.throwOnConversion,
      settings.pollingTimeout,
      settings.batchSize)

    val clientConf = new ClientConfiguration()

    settings.sslCACertFile.foreach(f => {
      clientConf.setUseTls(true)
      clientConf.setTlsTrustCertsFilePath(f)

      val authParams = settings.sslCertFile.map(f => ("tlsCertFile", f)).toMap ++ settings.sslCertKeyFile.map(f => ("tlsKeyFile", f)).toMap
      clientConf.setAuthentication(classOf[AuthenticationTls].getName, authParams.asJava)
    })

    pulsarManager = Some(new PulsarManager(PulsarClient.create(settings.connection, clientConf), name, settings.kcql, messageConverter))
    enableProgress = settings.enableProgress
  }

  def buildConvertersMap(props: util.Map[String, String], settings: PulsarSourceSettings): Map[String, Converter] = {
    settings.sourcesToConverters.map { case (topic, clazz) =>
      logger.info(s"Creating converter instance for $clazz")
      val converter = Try(Class.forName(clazz).newInstance()) match {
        case Success(value) => value.asInstanceOf[Converter]
        case Failure(_) => throw new ConfigException(s"Invalid ${PulsarConfigConstants.KCQL_CONFIG} is invalid. $clazz should have an empty ctor!")
      }
      import scala.collection.JavaConverters._
      converter.initialize(props.asScala.toMap)
      topic -> converter
    }
  }

  
  override def stop(): Unit = {
    logger.info("Stopping Pulsar source.")
    pulsarManager.foreach(_.close())
    progressCounter.empty
  }

  override def version: String = manifest.version()
} 
Example 91
Source File: PulsarSinkConnector.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.pulsar.sink

import java.util

import com.datamountaineer.streamreactor.connect.config.Helpers
import com.datamountaineer.streamreactor.connect.pulsar.config.{PulsarConfigConstants, PulsarSinkConfig}
import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigDef
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.sink.SinkConnector

import scala.collection.JavaConverters._

class PulsarSinkConnector extends SinkConnector with StrictLogging {
  private val configDef = PulsarSinkConfig.config
  private var configProps: Option[util.Map[String, String]] = None
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  override def start(props: util.Map[String, String]): Unit = {
    logger.info(s"Starting Pulsar sink connector.")
    Helpers.checkInputTopics(PulsarConfigConstants.KCQL_CONFIG, props.asScala.toMap)
    configProps = Some(props)
  }

  override def taskClass(): Class[_ <: Task] = classOf[PulsarSinkTask]

  override def version(): String = manifest.version()

  override def stop(): Unit = {}

  override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] = {
    logger.info(s"Setting task configurations for $maxTasks workers.")
    (1 to maxTasks).map(_ => configProps.get).toList.asJava
  }

  override def config(): ConfigDef = configDef
} 
Example 92
Source File: ConsumerConfigFactory.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.pulsar

import com.datamountaineer.kcql.Kcql
import com.typesafe.scalalogging.StrictLogging
import org.apache.pulsar.client.api.{ConsumerConfiguration, SubscriptionType}



object ConsumerConfigFactory extends StrictLogging {

  def apply(name: String, kcqls: Set[Kcql]): Map[String, ConsumerConfiguration] = {
    kcqls.map(kcql => {
      val config = new ConsumerConfiguration

      if (kcql.getBatchSize > 0) config.setReceiverQueueSize(kcql.getBatchSize)
      config.setSubscriptionType(getSubscriptionType(kcql))
      config.setConsumerName(name)
      (kcql.getSource, config)
    }).toMap
  }

  def getSubscriptionType(kcql: Kcql): SubscriptionType = {

    if (kcql.getWithSubscription() != null) {
      kcql.getWithSubscription.toUpperCase.trim match {
        case "EXCLUSIVE" =>
          SubscriptionType.Exclusive

        case "FAILOVER" =>
          SubscriptionType.Failover

        case "SHARED" =>
          SubscriptionType.Shared

        case _ =>
          logger.error(s"Unsupported subscription type ${kcql.getWithType} set in WITHTYPE. Defaulting to Failover")
          SubscriptionType.Failover
      }
    } else {
      logger.info("Defaulting to failover subscription type")
      SubscriptionType.Failover
    }
  }
} 
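A companion sketch for the consumer side, again with an assumed KCQL statement; WITHSUBSCRIPTION picks the Pulsar subscription type and BATCH sizes the receiver queue.

import com.datamountaineer.kcql.Kcql

val kcql = Kcql.parse(
  "INSERT INTO kafka_topic SELECT * FROM persistent://landoop/standalone/connect/pulsar-topic " +
    "BATCH = 10 WITHSUBSCRIPTION = SHARED")

val configs = ConsumerConfigFactory("connect-pulsar-consumer", Set(kcql))
val conf = configs("persistent://landoop/standalone/connect/pulsar-topic")
// conf.getSubscriptionType is SubscriptionType.Shared and the receiver queue size is 10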
Example 93
Source File: KElasticClient.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.elastic6

import com.datamountaineer.kcql.Kcql
import com.datamountaineer.streamreactor.connect.elastic6.config.ElasticSettings
import com.datamountaineer.streamreactor.connect.elastic6.indexname.CreateIndex.getIndexName
import com.sksamuel.elastic4s.bulk.BulkRequest
import com.sksamuel.elastic4s.http.bulk.BulkResponse
import com.sksamuel.elastic4s.http.{ElasticClient, ElasticNodeEndpoint, ElasticProperties, Response}
import com.sksamuel.elastic4s.mappings.MappingDefinition
import com.typesafe.scalalogging.StrictLogging
import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials}
import org.apache.http.client.config.RequestConfig.Builder
import org.apache.http.impl.client.BasicCredentialsProvider
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder

import scala.concurrent.Future

trait KElasticClient extends AutoCloseable {
  def index(kcql: Kcql)

  def execute(definition: BulkRequest): Future[Any]
}


object KElasticClient extends StrictLogging {

  def createHttpClient(settings: ElasticSettings, endpoints: Seq[ElasticNodeEndpoint]): KElasticClient = {
    if (settings.httpBasicAuthUsername.nonEmpty && settings.httpBasicAuthPassword.nonEmpty) {
      lazy val provider = {
        val provider = new BasicCredentialsProvider
        val credentials = new UsernamePasswordCredentials(settings.httpBasicAuthUsername, settings.httpBasicAuthPassword)
        provider.setCredentials(AuthScope.ANY, credentials)
        provider
      }

      val client: ElasticClient = ElasticClient(
        ElasticProperties(endpoints),
        (requestConfigBuilder: Builder) => requestConfigBuilder,
        (httpClientBuilder: HttpAsyncClientBuilder) => httpClientBuilder.setDefaultCredentialsProvider(provider)
      )
      new HttpKElasticClient(client)
    } else {
      val client: ElasticClient = ElasticClient(ElasticProperties(endpoints))
      new HttpKElasticClient(client)
    }
  }
}

class HttpKElasticClient(client: ElasticClient) extends KElasticClient {

  import com.sksamuel.elastic4s.http.ElasticDsl._

  override def index(kcql: Kcql): Unit = {
    require(kcql.isAutoCreate, s"Auto-creating indexes hasn't been enabled for target:${kcql.getTarget}")

    val indexName = getIndexName(kcql)
    client.execute {
      Option(kcql.getDocType) match {
        case None => createIndex(indexName)
        case Some(documentType) => createIndex(indexName).mappings(MappingDefinition(documentType))
      }
    }
  }

  override def execute(definition: BulkRequest): Future[Response[BulkResponse]] = client.execute(definition)

  override def close(): Unit = client.close()
} 
Example 94
Source File: ElasticJsonWriter.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.elastic6

import java.util

import com.datamountaineer.kcql.{Kcql, WriteModeEnum}
import com.datamountaineer.streamreactor.connect.converters.FieldConverter
import com.datamountaineer.streamreactor.connect.elastic6.config.ElasticSettings
import com.datamountaineer.streamreactor.connect.elastic6.indexname.CreateIndex
import com.datamountaineer.streamreactor.connect.errors.ErrorHandler
import com.datamountaineer.streamreactor.connect.schemas.ConverterUtil
import com.fasterxml.jackson.databind.JsonNode
import com.landoop.sql.Field
import com.sksamuel.elastic4s.Indexable
import com.sksamuel.elastic4s.http.ElasticDsl._
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.connect.sink.SinkRecord

import scala.collection.JavaConverters._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.Try

class ElasticJsonWriter(client: KElasticClient, settings: ElasticSettings)
  extends ErrorHandler with StrictLogging with ConverterUtil {

  logger.info("Initialising Elastic Json writer")

  //initialize error tracker
  initialize(settings.taskRetries, settings.errorPolicy)

  //create the index automatically if it was set to do so
  settings.kcqls.filter(_.isAutoCreate).foreach(client.index)

  private val topicKcqlMap = settings.kcqls.groupBy(_.getSource)

  private val kcqlMap = new util.IdentityHashMap[Kcql, KcqlValues]()
  settings.kcqls.foreach { kcql =>
    kcqlMap.put(kcql,
      KcqlValues(
        kcql.getFields.asScala.map(FieldConverter.apply),
        kcql.getIgnoredFields.asScala.map(FieldConverter.apply),
        kcql.getPrimaryKeys.asScala.map { pk =>
          val path = Option(pk.getParentFields).map(_.asScala.toVector).getOrElse(Vector.empty)
          path :+ pk.getName
        }
      ))

  }


  implicit object SinkRecordIndexable extends Indexable[SinkRecord] {
    override def json(t: SinkRecord): String = convertValueToJson(t).toString
  }

  
  def autoGenId(record: SinkRecord): String = {
    val pks = Seq(record.topic(), record.kafkaPartition(), record.kafkaOffset())
    pks.mkString(settings.pkJoinerSeparator)
  }

  private case class KcqlValues(fields: Seq[Field],
                                ignoredFields: Seq[Field],
                                primaryKeysPath: Seq[Vector[String]])

}


case object IndexableJsonNode extends Indexable[JsonNode] {
  override def json(t: JsonNode): String = t.toString
} 
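A standalone sketch mirroring the autoGenId logic above: the document id is simply topic, partition and offset joined by a separator. The "-" separator is an assumption; in the writer it comes from ElasticSettings.pkJoinerSeparator.

import org.apache.kafka.connect.data.Schema
import org.apache.kafka.connect.sink.SinkRecord

object AutoGenIdSketch extends App {
  // Mirrors ElasticJsonWriter.autoGenId (separator value assumed).
  def autoGenId(record: SinkRecord, separator: String = "-"): String =
    Seq(record.topic(), record.kafkaPartition(), record.kafkaOffset()).mkString(separator)

  val record = new SinkRecord("orders", 3, null, null, Schema.STRING_SCHEMA, """{"id":1}""", 42L)
  println(autoGenId(record)) // orders-3-42
}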
Example 95
Source File: ElasticSinkConnector.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.elastic6

import java.util

import com.datamountaineer.streamreactor.connect.config.Helpers
import com.datamountaineer.streamreactor.connect.elastic6.config.{ElasticConfig, ElasticConfigConstants}
import com.datamountaineer.streamreactor.connect.utils.JarManifest
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.config.ConfigDef
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.sink.SinkConnector

import scala.collection.JavaConverters._

class ElasticSinkConnector extends SinkConnector with StrictLogging {
  private var configProps : Option[util.Map[String, String]] = None
  private val configDef = ElasticConfig.config
  private val manifest = JarManifest(getClass.getProtectionDomain.getCodeSource.getLocation)

  
  override def start(props: util.Map[String, String]): Unit = {
    logger.info(s"Starting Elastic sink task.")
    Helpers.checkInputTopics(ElasticConfigConstants.KCQL, props.asScala.toMap)
    configProps = Some(props)
  }

  override def stop(): Unit = {}
  override def version(): String = manifest.version()
  override def config(): ConfigDef = configDef
} 
Example 96
Source File: StaticContentEndpoints.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.http
import akka.http.scaladsl.model._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.directives.ContentTypeResolver.Default
import akka.http.scaladsl.server.{Directives, RoutingLog}
import akka.http.scaladsl.settings.{ParserSettings, RoutingSettings}
import akka.stream.Materializer
import com.typesafe.scalalogging.StrictLogging
import scalaz.syntax.show._

import scala.concurrent.Future

object StaticContentEndpoints {
  def all(config: StaticContentConfig)(
      implicit
      routingSettings: RoutingSettings,
      parserSettings: ParserSettings,
      materializer: Materializer,
      routingLog: RoutingLog): HttpRequest PartialFunction Future[HttpResponse] =
    new StaticContentRouter(config)
}

private class StaticContentRouter(config: StaticContentConfig)(
    implicit
    routingSettings: RoutingSettings,
    parserSettings: ParserSettings,
    materializer: Materializer,
    routingLog: RoutingLog)
    extends PartialFunction[HttpRequest, Future[HttpResponse]]
    with StrictLogging {

  private val pathPrefix: Uri.Path = Uri.Path("/" + config.prefix)

  logger.warn(s"StaticContentRouter configured: ${config.shows}")
  logger.warn("DO NOT USE StaticContentRouter IN PRODUCTION, CONSIDER SETTING UP REVERSE PROXY!!!")

  private val fn =
    akka.http.scaladsl.server.Route.asyncHandler(
      Directives.rawPathPrefix(Slash ~ config.prefix)(
        Directives.getFromDirectory(config.directory.getAbsolutePath)
      ))

  override def isDefinedAt(x: HttpRequest): Boolean =
    x.uri.path.startsWith(pathPrefix)

  override def apply(x: HttpRequest): Future[HttpResponse] =
    fn(x)
} 
Example 97
Source File: Main.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.codegen

import com.daml.lf.codegen.conf.Conf
import com.typesafe.scalalogging.StrictLogging

import scala.util.control.NonFatal

object StandaloneMain extends StrictLogging {

  @deprecated("Use codegen font-end: com.daml.codegen.CodegenMain.main", "0.13.23")
  def main(args: Array[String]): Unit =
    try {
      Main.main(args)
    } catch {
      case NonFatal(t) =>
        logger.error(s"Error generating code: {}", t.getMessage)
        sys.exit(-1)
    }
}

object Main {
  @deprecated("Use codegen font-end: com.daml.codegen.CodegenMain.main", "0.13.23")
  def main(args: Array[String]): Unit =
    Conf.parse(args) match {
      case Some(conf) => CodeGenRunner.run(conf)
      case None =>
        throw new IllegalArgumentException(s"Invalid command line arguments: ${args.mkString(" ")}")
    }
} 
Example 98
Source File: JavaBackend.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.codegen.backend.java

import com.daml.lf.codegen.backend.Backend
import com.daml.lf.codegen.backend.java.inner.{ClassForType, DecoderClass}
import com.daml.lf.codegen.conf.Conf
import com.daml.lf.codegen.{InterfaceTrees, ModuleWithContext, NodeWithContext}
import com.daml.lf.data.Ref.PackageId
import com.daml.lf.iface.Interface
import com.squareup.javapoet._
import com.typesafe.scalalogging.StrictLogging
import org.slf4j.MDC

import scala.concurrent.{ExecutionContext, Future}

private[codegen] object JavaBackend extends Backend with StrictLogging {

  override def preprocess(
      interfaces: Seq[Interface],
      conf: Conf,
      packagePrefixes: Map[PackageId, String])(
      implicit ec: ExecutionContext): Future[InterfaceTrees] = {
    val tree = InterfaceTrees.fromInterfaces(interfaces)
    for ((decoderPkg, decoderClassName) <- conf.decoderPkgAndClass) {
      val templateNames = extractTemplateNames(tree, packagePrefixes)
      val decoderFile = JavaFile
        .builder(
          decoderPkg,
          DecoderClass.generateCode(decoderClassName, templateNames)
        )
        .build()
      decoderFile.writeTo(conf.outputDirectory)
    }
    Future.successful(tree)
  }

  private def extractTemplateNames(
      tree: InterfaceTrees,
      packagePrefixes: Map[PackageId, String]) = {
    val prefixes = packagePrefixes.mapValues(_.stripSuffix("."))
    tree.interfaceTrees.flatMap(_.bfs(Vector[ClassName]()) {
      case (res, module: ModuleWithContext) =>
        val templateNames = module.typesLineages
          .collect {
            case t if t.`type`.typ.exists(_.getTemplate.isPresent) =>
              ClassName.bestGuess(inner.fullyQualifiedName(t.identifier, packagePrefixes))
          }
        res ++ templateNames
      case (res, _) => res
    })
  }

  def process(
      nodeWithContext: NodeWithContext,
      conf: Conf,
      packagePrefixes: Map[PackageId, String])(implicit ec: ExecutionContext): Future[Unit] = {
    val prefixes = packagePrefixes.mapValues(_.stripSuffix("."))
    nodeWithContext match {
      case moduleWithContext: ModuleWithContext if moduleWithContext.module.types.nonEmpty =>
        // this is a DAML module that contains type declarations => the codegen will create one file
        Future {
          logger.info(
            s"Generating code for module ${moduleWithContext.lineage.map(_._1).toSeq.mkString(".")}")
          for (javaFile <- createTypeDefinitionClasses(moduleWithContext, prefixes)) {
            logger.info(
              s"Writing ${javaFile.packageName}.${javaFile.typeSpec.name} to directory ${conf.outputDirectory}")
            javaFile.writeTo(conf.outputDirectory)

          }
        }
      case _ =>
        Future.successful(())
    }
  }

  private def createTypeDefinitionClasses(
      moduleWithContext: ModuleWithContext,
      packagePrefixes: Map[PackageId, String]): Iterable[JavaFile] = {
    MDC.put("packageId", moduleWithContext.packageId)
    MDC.put("packageIdShort", moduleWithContext.packageId.take(7))
    MDC.put("moduleName", moduleWithContext.name)
    val typeSpecs = for {
      typeWithContext <- moduleWithContext.typesLineages
      javaFile <- ClassForType(typeWithContext, packagePrefixes)
    } yield {
      javaFile
    }
    MDC.remove("packageId")
    MDC.remove("packageIdShort")
    MDC.remove("moduleName")
    typeSpecs
  }
} 
Example 99
Source File: ClassForType.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.codegen.backend.java.inner
import com.daml.lf.codegen.TypeWithContext
import com.daml.lf.codegen.backend.java.JavaEscaper
import com.daml.lf.data.Ref.PackageId
import com.daml.lf.iface.InterfaceType.{Normal, Template}
import com.daml.lf.iface.{Enum, DefDataType, Record, Variant}
import com.squareup.javapoet.{ClassName, FieldSpec, JavaFile, TypeSpec}
import com.typesafe.scalalogging.StrictLogging
import javax.lang.model.element.Modifier

object ClassForType extends StrictLogging {

  def apply(
      typeWithContext: TypeWithContext,
      packagePrefixes: Map[PackageId, String]): List[JavaFile] = {

    val className =
      ClassName.bestGuess(fullyQualifiedName(typeWithContext.identifier, packagePrefixes))
    val javaPackage = className.packageName()

    typeWithContext.`type`.typ match {

      case Some(Normal(DefDataType(typeVars, record: Record.FWT))) =>
        val typeSpec =
          RecordClass.generate(
            className,
            typeVars.map(JavaEscaper.escapeString),
            record,
            None,
            packagePrefixes)
        List(javaFile(typeWithContext, javaPackage, typeSpec))

      case Some(Normal(DefDataType(typeVars, variant: Variant.FWT))) =>
        val subPackage = className.packageName() + "." + JavaEscaper.escapeString(
          className.simpleName().toLowerCase)
        val (tpe, constructors) =
          VariantClass.generate(
            className,
            subPackage,
            typeVars.map(JavaEscaper.escapeString),
            variant,
            typeWithContext,
            packagePrefixes)
        javaFile(typeWithContext, javaPackage, tpe) ::
          constructors.map(cons => javaFile(typeWithContext, subPackage, cons))

      case Some(Normal(DefDataType(_, enum: Enum))) =>
        List(
          JavaFile
            .builder(javaPackage, EnumClass.generate(className, typeWithContext.identifier, enum))
            .build())

      case Some(Template(record, template)) =>
        val typeSpec =
          TemplateClass.generate(className, record, template, typeWithContext, packagePrefixes)
        List(JavaFile.builder(javaPackage, typeSpec).build())

      case None =>
        // This typeWithContext didn't contain a type itself, but has children nodes
        // which we treat as any other TypeWithContext
        typeWithContext.typesLineages.flatMap(ClassForType(_, packagePrefixes)).toList
    }
  }

  def javaFile(typeWithContext: TypeWithContext, javaPackage: String, typeSpec: TypeSpec) = {
    val withField =
      typeSpec.toBuilder.addField(createPackageIdField(typeWithContext.interface.packageId)).build()
    JavaFile.builder(javaPackage, withField).build()
  }

  private def createPackageIdField(packageId: PackageId): FieldSpec = {
    FieldSpec
      .builder(classOf[String], "_packageId", Modifier.FINAL, Modifier.PUBLIC, Modifier.STATIC)
      .initializer("$S", packageId)
      .build()
  }
} 
Example 100
Source File: VariantRecordMethods.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.codegen.backend.java.inner

import com.daml.ledger.javaapi
import com.daml.lf.codegen.backend.java.ObjectMethods
import com.daml.lf.data.Ref.PackageId
import com.squareup.javapoet._
import com.typesafe.scalalogging.StrictLogging

private[inner] object VariantRecordMethods extends StrictLogging {

  def apply(
      constructorName: String,
      fields: Fields,
      className: TypeName,
      typeParameters: IndexedSeq[String],
      packagePrefixes: Map[PackageId, String]): Vector[MethodSpec] = {
    val constructor = ConstructorGenerator.generateConstructor(fields)

    val conversionMethods = distinctTypeVars(fields, typeParameters).flatMap { params =>
      val toValue = ToValueGenerator.generateToValueForRecordLike(
        params,
        fields,
        packagePrefixes,
        TypeName.get(classOf[javaapi.data.Variant]),
        name =>
          CodeBlock.of(
            "return new $T($S, new $T($L))",
            classOf[javaapi.data.Variant],
            constructorName,
            classOf[javaapi.data.Record],
            name)
      )
      val fromValue = FromValueGenerator.generateFromValueForRecordLike(
        fields,
        className,
        params,
        FromValueGenerator.variantCheck(constructorName, _, _),
        packagePrefixes)
      List(toValue, fromValue)
    }

    Vector(constructor) ++ conversionMethods ++
      ObjectMethods(className.rawType, typeParameters, fields.map(_.javaName))
  }

} 
Example 101
Source File: RecordClass.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.codegen.backend.java.inner
import com.daml.lf.data.Ref.PackageId
import com.daml.lf.iface.Record
import com.squareup.javapoet.{ClassName, TypeName, TypeSpec, TypeVariableName}
import com.typesafe.scalalogging.StrictLogging
import javax.lang.model.element.Modifier

import scala.collection.JavaConverters._

private[inner] object RecordClass extends StrictLogging {

  def generate(
      className: ClassName,
      typeParameters: IndexedSeq[String],
      record: Record.FWT,
      superclass: Option[TypeName],
      packagePrefixes: Map[PackageId, String]): TypeSpec = {
    TrackLineage.of("record", className.simpleName()) {
      logger.info("Start")
      val fields = getFieldsWithTypes(record.fields, packagePrefixes)
      val recordType = TypeSpec
        .classBuilder(className)
        .addModifiers(Modifier.PUBLIC)
        .addTypeVariables(typeParameters.map(TypeVariableName.get).asJava)
        .addFields(RecordFields(fields).asJava)
        .addMethods(RecordMethods(fields, className, typeParameters, packagePrefixes).asJava)
        .build()
      logger.debug("End")
      recordType
    }
  }
} 
Example 102
Source File: VariantRecordClass.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.codegen.backend.java.inner

import com.daml.lf.data.Ref.PackageId
import com.squareup.javapoet._
import com.typesafe.scalalogging.StrictLogging
import javax.lang.model.element.Modifier

import scala.collection.JavaConverters._

private[inner] object VariantRecordClass extends StrictLogging {

  def generate(
      typeParameters: IndexedSeq[String],
      fields: Fields,
      name: String,
      superclass: TypeName,
      packagePrefixes: Map[PackageId, String]): TypeSpec.Builder =
    TrackLineage.of("variant-record", name) {
      logger.info("Start")
      val className = ClassName.bestGuess(name)
      val builder = TypeSpec
        .classBuilder(name)
        .addModifiers(Modifier.PUBLIC)
        .superclass(superclass)
        .addTypeVariables(typeParameters.map(TypeVariableName.get).asJava)
        .addFields(RecordFields(fields).asJava)
        .addMethods(
          VariantRecordMethods(
            name,
            fields,
            className.parameterized(typeParameters),
            typeParameters,
            packagePrefixes).asJava)
      logger.debug("End")
      builder
    }
} 
Example 103
Source File: Main.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.codegen

import java.io.File
import java.nio.file.Path

import ch.qos.logback.classic.Level
import com.daml.lf.codegen.conf.Conf
import com.typesafe.scalalogging.StrictLogging
import org.slf4j.{Logger, LoggerFactory}
import scalaz.Cord

import scala.collection.breakOut

object Main extends StrictLogging {

  private val codegenId = "Scala Codegen"

  @deprecated("Use codegen font-end: com.daml.codegen.CodegenMain.main", "0.13.23")
  def main(args: Array[String]): Unit =
    Conf.parse(args) match {
      case Some(conf) =>
        generateCode(conf)
      case None =>
        throw new IllegalArgumentException(
          s"Invalid ${codegenId: String} command line arguments: ${args.mkString(" "): String}")
    }

  def generateCode(conf: Conf): Unit = conf match {
    case Conf(darMap, outputDir, decoderPkgAndClass, verbosity, roots) =>
      setGlobalLogLevel(verbosity)
      logUnsupportedEventDecoderOverride(decoderPkgAndClass)
      val (dars, packageName) = darsAndOnePackageName(darMap)
      CodeGen.generateCode(dars, packageName, outputDir.toFile, CodeGen.Novel, roots)
  }

  private def setGlobalLogLevel(verbosity: Level): Unit = {
    LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME) match {
      case a: ch.qos.logback.classic.Logger =>
        a.setLevel(verbosity)
        logger.info(s"${codegenId: String} verbosity: ${verbosity.toString}")
      case _ =>
        logger.warn(s"${codegenId: String} cannot set requested verbosity: ${verbosity.toString}")
    }
  }

  private def logUnsupportedEventDecoderOverride(mapping: Option[(String, String)]): Unit =
    mapping.foreach {
      case (a, b) =>
        logger.warn(
          s"${codegenId: String} does not allow overriding Event Decoder, skipping: ${a: String} -> ${b: String}")
    }

  private def darsAndOnePackageName(darMap: Map[Path, Option[String]]): (List[File], String) = {
    val dars: List[File] = darMap.keys.map(_.toFile)(breakOut)
    val uniquePackageNames: Set[String] = darMap.values.collect { case Some(x) => x }(breakOut)
    uniquePackageNames.toSeq match {
      case Seq(packageName) =>
        (dars, packageName)
      case _ =>
        throw new IllegalStateException(
          s"${codegenId: String} expects all dars mapped to the same package name, " +
            s"requested: ${format(darMap): String}")
    }
  }

  private def format(map: Map[Path, Option[String]]): String = {
    val cord = map.foldLeft(Cord("{")) { (str, kv) =>
      str ++ kv._1.toFile.getAbsolutePath ++ "->" ++ kv._2.toString ++ ","
    }
    (cord ++ "}").toString
  }
} 
Example 104
Source File: TriggerRunner.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.engine.trigger

import akka.actor.typed.{Behavior, PostStop}
import akka.actor.typed.scaladsl.AbstractBehavior
import akka.actor.typed.SupervisorStrategy._
import akka.actor.typed.Signal
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.scaladsl.ActorContext
import akka.stream.Materializer
import com.typesafe.scalalogging.StrictLogging
import com.daml.grpc.adapter.ExecutionSequencerFactory

class InitializationHalted(s: String) extends Exception(s) {}
class InitializationException(s: String) extends Exception(s) {}

object TriggerRunner {
  type Config = TriggerRunnerImpl.Config

  trait Message
  final case object Stop extends Message

  def apply(config: Config, name: String)(
      implicit esf: ExecutionSequencerFactory,
      mat: Materializer): Behavior[TriggerRunner.Message] =
    Behaviors.setup(ctx => new TriggerRunner(ctx, config, name))
}

class TriggerRunner(
    ctx: ActorContext[TriggerRunner.Message],
    config: TriggerRunner.Config,
    name: String)(implicit esf: ExecutionSequencerFactory, mat: Materializer)
    extends AbstractBehavior[TriggerRunner.Message](ctx)
    with StrictLogging {

  import TriggerRunner.{Message, Stop}

  // Spawn a trigger runner impl. Supervise it. Stop immediately on
  // initialization halted exceptions, retry any initialization or
  // execution failure exceptions.
  private val child =
    ctx.spawn(
      Behaviors
        .supervise(
          Behaviors
            .supervise(TriggerRunnerImpl(config))
            .onFailure[InitializationHalted](stop)
        )
        .onFailure(
          restartWithBackoff(
            config.restartConfig.minRestartInterval,
            config.restartConfig.maxRestartInterval,
            config.restartConfig.restartIntervalRandomFactor)),
      name
    )

  override def onMessage(msg: Message): Behavior[Message] =
    Behaviors.receiveMessagePartial[Message] {
      case Stop =>
        Behaviors.stopped // Automatically stops the child actor if running.
    }

  override def onSignal: PartialFunction[Signal, Behavior[Message]] = {
    case PostStop =>
      logger.info(s"Trigger $name stopped")
      this
  }

} 
Example 105
Source File: CommitMarkerOffsetsActor.scala    From kmq   with Apache License 2.0 5 votes vote down vote up
package com.softwaremill.kmq.redelivery

import akka.actor.Actor
import com.softwaremill.kmq.KafkaClients
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.ByteArrayDeserializer

import scala.collection.JavaConverters._
import scala.concurrent.duration._

class CommitMarkerOffsetsActor(markerTopic: String, clients: KafkaClients) extends Actor with StrictLogging {

  private val consumer = clients.createConsumer(null, classOf[ByteArrayDeserializer], classOf[ByteArrayDeserializer])

  private var toCommit: Map[Partition, Offset] = Map()

  import context.dispatcher

  override def preStart(): Unit = {
    logger.info("Started commit marker offsets actor")
  }

  override def postStop(): Unit = {
    try consumer.close()
    catch {
      case e: Exception => logger.error("Cannot close commit offsets consumer", e)
    }

    logger.info("Stopped commit marker offsets actor")
  }

  override def receive: Receive = {
    case CommitOffset(p, o) =>
      // only updating if the current offset is smaller
      if (toCommit.get(p).fold(true)(_ < o))
        toCommit += p -> o

    case DoCommit =>
      try {
        commitOffsets()
        toCommit = Map()
      } finally context.system.scheduler.scheduleOnce(1.second, self, DoCommit)
  }

  private def commitOffsets(): Unit = if (toCommit.nonEmpty) {
    consumer.commitSync(toCommit.map { case (partition, offset) =>
      (new TopicPartition(markerTopic, partition), new OffsetAndMetadata(offset))
    }.asJava)

    logger.debug(s"Committed marker offsets: $toCommit")
  }
} 
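A standalone sketch of the CommitOffset bookkeeping above: for each partition only a larger offset overwrites the stored one, so DoCommit always commits the highest offset seen. Plain Int/Long stand in for the Partition/Offset aliases from the redelivery package.

object CommitOffsetFoldSketch extends App {
  type Partition = Int
  type Offset    = Long

  var toCommit: Map[Partition, Offset] = Map()

  // Same guard as in CommitMarkerOffsetsActor.receive for CommitOffset(p, o).
  def commitOffset(p: Partition, o: Offset): Unit =
    if (toCommit.get(p).fold(true)(_ < o)) toCommit += p -> o

  commitOffset(0, 10L)
  commitOffset(0, 7L)  // ignored, 10 is already stored
  commitOffset(0, 15L) // replaces 10
  println(toCommit)    // Map(0 -> 15)
}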
Example 106
Source File: RedeliveryActors.scala    From kmq   with Apache License 2.0 5 votes vote down vote up
package com.softwaremill.kmq.redelivery

import java.io.Closeable
import java.util.Collections

import akka.actor.{ActorSystem, Props}
import com.softwaremill.kmq.{KafkaClients, KmqConfig}
import com.typesafe.scalalogging.StrictLogging

import scala.concurrent.Await
import scala.concurrent.duration._
import scala.collection.JavaConverters._

object RedeliveryActors extends StrictLogging {
  def start(clients: KafkaClients, config: KmqConfig): Closeable = {
    val system = ActorSystem("kmq-redelivery")

    val consumeMarkersActor = system.actorOf(Props(new ConsumeMarkersActor(clients, config)), "consume-markers-actor")
    consumeMarkersActor ! DoConsume

    logger.info("Started redelivery actors")

    new Closeable {
      override def close(): Unit = Await.result(system.terminate(), 1.minute)
    }
  }
} 
Example 107
Source File: RedeliverActor.scala    From kmq   with Apache License 2.0 5 votes vote down vote up
package com.softwaremill.kmq.redelivery

import akka.actor.Actor
import com.softwaremill.kmq.MarkerKey
import com.typesafe.scalalogging.StrictLogging

import scala.concurrent.duration._

class RedeliverActor(p: Partition, redeliverer: Redeliverer) extends Actor with StrictLogging {

  private var toRedeliver: List[MarkerKey] = Nil

  import context.dispatcher

  override def preStart(): Unit = {
    logger.info(s"${self.path} Started redeliver actor for partition $p")
  }

  override def postStop(): Unit = {
    try redeliverer.close()
    catch {
      case e: Exception => logger.error(s"Cannot close redeliverer for partition $p", e)
    }
    
    logger.info(s"${self.path} Stopped redeliver actor for partition $p")
  }

  override def receive: Receive = {
    case RedeliverMarkers(m) =>
      toRedeliver ++= m

    case DoRedeliver =>
      val hadRedeliveries = toRedeliver.nonEmpty
      try {
        redeliverer.redeliver(toRedeliver)
        toRedeliver = Nil
      } finally {
        if (hadRedeliveries) {
          self ! DoRedeliver
        } else {
          context.system.scheduler.scheduleOnce(1.second, self, DoRedeliver)
        }
      }
  }
} 
Example 108
Source File: Stopwatch.scala    From keycloak-benchmark   with Apache License 2.0 5 votes vote down vote up
package io.gatling.keycloak

import com.typesafe.scalalogging.StrictLogging
import io.gatling.core.action.{UserEnd, Chainable}
import io.gatling.core.akka.GatlingActorSystem
import io.gatling.core.result.message.{OK, KO, Status}
import io.gatling.core.result.writer.{DataWriter, DataWriterClient}
import io.gatling.core.session.Session
import io.gatling.core.util.TimeHelper
import io.gatling.core.validation.{Validation, Success, Failure}


object Stopwatch extends StrictLogging {
  @volatile var recording: Boolean = true;
  GatlingActorSystem.instance.registerOnTermination(() => recording = false)

  def apply[T](f: () => T): Result[T] = {
    val start = TimeHelper.nowMillis
    try {
      val result = f()
      Result(Success(result), OK, start, TimeHelper.nowMillis, false)
    } catch {
      case ie: InterruptedException => {
        Result(Failure("Interrupted"), KO, start, start, true)
      }
      case e: Throwable => {
        Stopwatch.log.error("Operation failed with exception", e)
        Result(Failure(e.toString), KO, start, TimeHelper.nowMillis, false)
      }
    }
  }

  def log = logger;
}

case class Result[T](
                      val value: Validation[T],
                      val status: Status,
                      val startTime: Long,
                      val endTime: Long,
                      val interrupted: Boolean
) {
  def check(check: T => Boolean, fail: T => String): Result[T] = {
     value match {
       case Success(v) =>
         if (!check(v)) {
           Result(Failure(fail(v)), KO, startTime, endTime, interrupted);
         } else {
           this
         }
       case _ => this
     }
  }

  def isSuccess =
    value match {
      case Success(_) => true
      case _ => false
    }

  private def record(client: DataWriterClient, session: Session, name: String): Validation[T] = {
    if (!interrupted && Stopwatch.recording) {
      client.writeRequestData(session, name, startTime, startTime, endTime, endTime, status)
    }
    value
  }

  def recordAndStopOnFailure(client: DataWriterClient with Chainable, session: Session, name: String): Validation[T] = {
    val validation = record(client, session, name)
    validation.onFailure(message => {
        Stopwatch.log.error(s"'${client.self.path.name}', ${session.userId} failed to execute: $message")
        UserEnd.instance ! session.markAsFailed
    })
    validation
  }

  def recordAndContinue(client: DataWriterClient with Chainable, session: Session, name: String): Unit = {
    // can't specify follow function as default arg since it uses another parameter
    recordAndContinue(client, session, name, _ => session);
  }

  def recordAndContinue(client: DataWriterClient with Chainable, session: Session, name: String, follow: T => Session): Unit = {
    // 'follow' intentionally does not get session as arg, since caller site already has the reference
    record(client, session, name) match {
      case Success(value) => try {
        client.next ! follow(value)
      } catch {
        case t: Throwable => {
          Stopwatch.log.error(s"'${client.self.path.name}' failed processing", t)
          UserEnd.instance ! session.markAsFailed
        }
      }
      case Failure(message) => {
        Stopwatch.log.error(s"'${client.self.path.name}' failed to execute: $message")
        UserEnd.instance ! session.markAsFailed
      }
    }
  }
} 
Example 109
Source File: Worker.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.local.miner

import java.util.Date

import akka.actor.{Actor, ActorRef}
import encry.EncryApp._

import scala.concurrent.duration._
import encry.consensus.{CandidateBlock, ConsensusSchemeReaders}
import encry.local.miner.Miner.MinedBlock
import encry.local.miner.Worker.{MineBlock, NextChallenge}
import java.text.SimpleDateFormat

import com.typesafe.scalalogging.StrictLogging
import org.encryfoundation.common.utils.constants.TestNetConstants

class Worker(myIdx: Int, numberOfWorkers: Int, miner: ActorRef) extends Actor with StrictLogging {

  val sdf: SimpleDateFormat = new SimpleDateFormat("HH:mm:ss")
  var challengeStartTime: Date = new Date(System.currentTimeMillis())

  val initialNonce: Long = Long.MaxValue / numberOfWorkers * myIdx

  override def preRestart(reason: Throwable, message: Option[Any]): Unit =
    logger.warn(s"Worker $myIdx is restarting because of: $reason")

  override def receive: Receive = {
    case MineBlock(candidate: CandidateBlock, nonce: Long) =>
      logger.info(s"Trying nonce: $nonce. Start nonce is: $initialNonce. " +
        s"Iter qty: ${nonce - initialNonce + 1} on worker: $myIdx with diff: ${candidate.difficulty}")
      ConsensusSchemeReaders
        .consensusScheme.verifyCandidate(candidate, nonce)
        .fold(
          e => {
            self ! MineBlock(candidate, nonce + 1)
            logger.info(s"Mining failed cause: $e")
          },
          block => {
            logger.info(s"New block is found: (${block.header.height}, ${block.header.encodedId}, ${block.payload.txs.size} " +
              s"on worker $self at ${sdf.format(new Date(System.currentTimeMillis()))}. Iter qty: ${nonce - initialNonce + 1}")
            miner ! MinedBlock(block, myIdx)
          })
    case NextChallenge(candidate: CandidateBlock) =>
      challengeStartTime = new Date(System.currentTimeMillis())
      logger.info(s"Start next challenge on worker: $myIdx at height " +
        s"${candidate.parentOpt.map(_.height + 1).getOrElse(TestNetConstants.PreGenesisHeight.toString)} at ${sdf.format(challengeStartTime)}")
      self ! MineBlock(candidate, Long.MaxValue / numberOfWorkers * myIdx)
  }

}

object Worker {

  case class NextChallenge(candidateBlock: CandidateBlock)

  case class MineBlock(candidateBlock: CandidateBlock, nonce: Long)

} 
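A standalone sketch of the nonce-space partitioning used by Worker: each of the N workers starts at Long.MaxValue / N * idx and increments by one, so the workers' ranges never overlap. The worker count below is illustrative.

object NoncePartitionSketch extends App {
  val numberOfWorkers = 4 // illustrative
  (0 until numberOfWorkers).foreach { idx =>
    val initialNonce = Long.MaxValue / numberOfWorkers * idx
    println(s"worker $idx starts at nonce $initialNonce")
  }
}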
Example 110
Source File: ModifiersToNetworkUtils.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.network

import HeaderProto.HeaderProtoMessage
import PayloadProto.PayloadProtoMessage
import com.typesafe.scalalogging.StrictLogging
import encry.modifiers.history.{ HeaderUtils, PayloadUtils }
import encry.settings.EncryAppSettings
import encry.view.history.History
import org.encryfoundation.common.modifiers.PersistentModifier
import org.encryfoundation.common.modifiers.history.{ Header, HeaderProtoSerializer, Payload, PayloadProtoSerializer }
import org.encryfoundation.common.utils.TaggedTypes.ModifierTypeId
import cats.syntax.either._
import encry.modifiers.history.HeaderUtils.PreSemanticValidationException
import scala.util.{ Failure, Try }

object ModifiersToNetworkUtils extends StrictLogging {

  def toProto(modifier: PersistentModifier): Array[Byte] = modifier match {
    case m: Header  => HeaderProtoSerializer.toProto(m).toByteArray
    case m: Payload => PayloadProtoSerializer.toProto(m).toByteArray
    case m          => throw new RuntimeException(s"Try to serialize unknown modifier: $m to proto.")
  }

  def fromProto(modType: ModifierTypeId, bytes: Array[Byte]): Try[PersistentModifier] =
    Try(modType match {
      case Header.modifierTypeId  => HeaderProtoSerializer.fromProto(HeaderProtoMessage.parseFrom(bytes))
      case Payload.modifierTypeId => PayloadProtoSerializer.fromProto(PayloadProtoMessage.parseFrom(bytes))
      case m                      => Failure(new RuntimeException(s"Try to deserialize unknown modifier: $m from proto."))
    }).flatten

  def isSyntacticallyValid(modifier: PersistentModifier, modifierIdSize: Int): Boolean = modifier match {
    case h: Header  => HeaderUtils.syntacticallyValidity(h, modifierIdSize).isSuccess
    case p: Payload => PayloadUtils.syntacticallyValidity(p, modifierIdSize).isSuccess
    case _          => true
  }

  def isPreSemanticValidation(modifier: PersistentModifier,
                              history: History,
                              settings: EncryAppSettings): Either[PreSemanticValidationException, Unit] =
    modifier match {
      case h: Header => HeaderUtils.preSemanticValidation(h, history, settings)
      case _         => ().asRight[PreSemanticValidationException]
    }
} 
Example 111
Source File: PrioritiesCalculator.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.network

import java.net.InetSocketAddress

import com.typesafe.scalalogging.StrictLogging
import encry.network.PrioritiesCalculator.PeersPriorityStatus
import encry.network.PrioritiesCalculator.PeersPriorityStatus.PeersPriorityStatus._
import encry.network.PrioritiesCalculator.PeersPriorityStatus._
import encry.settings.NetworkSettings

import scala.concurrent.duration._

final case class PrioritiesCalculator(networkSettings: NetworkSettings,
                                      private val peersNetworkStatistic: Map[InetSocketAddress, (Requested, Received)])
  extends StrictLogging {

  val updatingStatisticTime: FiniteDuration = (networkSettings.deliveryTimeout._1 * networkSettings.maxDeliveryChecks).seconds

  def incrementRequest(peer: InetSocketAddress): PrioritiesCalculator = {
    val (requested, received): (Requested, Received) = peersNetworkStatistic.getOrElse(peer, (Requested(), Received()))
    val newRequested: Requested = requested.increment
    logger.debug(s"Updating request parameter from $peer. Old is ($requested, $received). New one is: ($newRequested, $received)")
    PrioritiesCalculator(networkSettings, peersNetworkStatistic.updated(peer, (newRequested, received)))
  }

  def incrementReceive(peer: InetSocketAddress): PrioritiesCalculator = {
    val (requested, received): (Requested, Received) = peersNetworkStatistic.getOrElse(peer, (Requested(), Received()))
    val newReceived: Received = received.increment
    logger.debug(s"Updating received parameter from $peer. Old is ($requested, $received). New one is: ($requested, $newReceived)")
    PrioritiesCalculator(networkSettings, peersNetworkStatistic.updated(peer, (requested, newReceived)))
  }

  def decrementRequest(peer: InetSocketAddress): PrioritiesCalculator = {
    val (requested, received): (Requested, Received) = peersNetworkStatistic.getOrElse(peer, (Requested(), Received()))
    val newRequested: Requested = requested.decrement
    logger.debug(s"Decrement request parameter from $peer. Old is ($requested, $received). New one is: ($newRequested, $received)")
    PrioritiesCalculator(networkSettings, peersNetworkStatistic.updated(peer, (newRequested, received)))
  }

  def incrementRequestForNModifiers(peer: InetSocketAddress, modifiersQty: Int): PrioritiesCalculator = {
    val (requested, received): (Requested, Received) = peersNetworkStatistic.getOrElse(peer, (Requested(), Received()))
    val newRequested: Requested = requested.incrementForN(modifiersQty)
    logger.debug(s"Updating request parameter from $peer. Old is ($requested, $received). New one is: ($newRequested, $received)")
    PrioritiesCalculator(networkSettings, peersNetworkStatistic.updated(peer, (newRequested, received)))
  }

  def accumulatePeersStatistic: (Map[InetSocketAddress, PeersPriorityStatus], PrioritiesCalculator) = {
    val updatedStatistic: Map[InetSocketAddress, PeersPriorityStatus] = peersNetworkStatistic.map {
      case (peer, (requested, received)) =>
        logger.info(s"peer: $peer: received: $received, requested: $requested")
        val priority: PeersPriorityStatus = PeersPriorityStatus.calculateStatuses(received, requested)
        peer -> priority
    }
    logger.info(s"Accumulated peers statistic. Current stats are: ${updatedStatistic.mkString(",")}")
    (updatedStatistic, PrioritiesCalculator(networkSettings))
  }
}

object PrioritiesCalculator {

  final case class AccumulatedPeersStatistic(statistic: Map[InetSocketAddress, PeersPriorityStatus])

  object PeersPriorityStatus {

    sealed trait PeersPriorityStatus
    object PeersPriorityStatus {
      case object HighPriority extends PeersPriorityStatus
      case object LowPriority extends PeersPriorityStatus
      case object InitialPriority extends PeersPriorityStatus
      case object BadNode extends PeersPriorityStatus
    }

    final case class Received(received: Int = 0) extends AnyVal {
      def increment: Received = Received(received + 1)
    }

    final case class Requested(requested: Int = 0) extends AnyVal {
      def increment: Requested = Requested(requested + 1)

      def decrement: Requested = Requested(requested - 1)

      def incrementForN(n: Int): Requested = Requested(requested + n)
    }

    private val criterionForHighP: Double = 0.75
    private val criterionForLowP: Double  = 0.50

    def calculateStatuses(res: Received, req: Requested): PeersPriorityStatus =
      res.received.toDouble / req.requested match {
        case t if t >= criterionForHighP => HighPriority
        case t if t >= criterionForLowP  => LowPriority
        case _                           => BadNode
      }
  }

  def apply(networkSettings: NetworkSettings): PrioritiesCalculator =
    PrioritiesCalculator(networkSettings, Map.empty[InetSocketAddress, (Requested, Received)])
} 
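A quick check of the priority thresholds above (0.75 and 0.50), assuming the PrioritiesCalculator listing is on the classpath.

import encry.network.PrioritiesCalculator.PeersPriorityStatus._

object PriorityThresholdsSketch extends App {
  println(calculateStatuses(Received(8), Requested(10))) // 0.80 >= 0.75 -> HighPriority
  println(calculateStatuses(Received(6), Requested(10))) // 0.60 >= 0.50 -> LowPriority
  println(calculateStatuses(Received(3), Requested(10))) // 0.30         -> BadNode
}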
Example 112
Source File: ConnectedPeersCollection.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.network

import java.net.InetSocketAddress
import com.typesafe.scalalogging.StrictLogging
import encry.consensus.HistoryConsensus.{HistoryComparisonResult, Unknown}
import encry.network.ConnectedPeersCollection.{LastUptime, PeerInfo}
import encry.network.PeerConnectionHandler.{ConnectedPeer, ConnectionType, Outgoing}
import encry.network.PrioritiesCalculator.PeersPriorityStatus.PeersPriorityStatus.InitialPriority
import encry.network.PrioritiesCalculator.PeersPriorityStatus.PeersPriorityStatus

final case class ConnectedPeersCollection(private val peers: Map[InetSocketAddress, PeerInfo]) extends StrictLogging {

  val size: Int = peers.size

  def contains(peer: InetSocketAddress): Boolean = peers.contains(peer)

  def initializePeer(cp: ConnectedPeer): ConnectedPeersCollection = ConnectedPeersCollection(peers.updated(
    cp.socketAddress, PeerInfo(Unknown, InitialPriority, cp, Outgoing, LastUptime(0))
  ))

  def removePeer(address: InetSocketAddress): ConnectedPeersCollection = ConnectedPeersCollection(peers - address)

  def updatePriorityStatus(stats: Map[InetSocketAddress, PeersPriorityStatus]): ConnectedPeersCollection =
    ConnectedPeersCollection(updateK(stats, updateStatus))

  def updateHistoryComparisonResult(hcr: Map[InetSocketAddress, HistoryComparisonResult]): ConnectedPeersCollection =
    ConnectedPeersCollection(updateK(hcr, updateComparisonResult))

  def updateLastUptime(lup: Map[InetSocketAddress, LastUptime]): ConnectedPeersCollection =
    ConnectedPeersCollection(updateK(lup, updateUptime))

  def collect[T](p: (InetSocketAddress, PeerInfo) => Boolean,
                 f: (InetSocketAddress, PeerInfo) => T): Seq[T] = peers
    .collect { case (peer, info) if p(peer, info) => f(peer, info) }
    .toSeq

  def getAll: Map[InetSocketAddress, PeerInfo] = peers

  private def updateK[T](elems: Map[InetSocketAddress, T], f: (PeerInfo, T) => PeerInfo): Map[InetSocketAddress, PeerInfo] = {
    val newValue: Map[InetSocketAddress, PeerInfo] = for {
      (key, value) <- elems
      oldValue     <- peers.get(key)
    } yield key -> f(oldValue, value)
    peers ++ newValue
  }

  private def updateStatus: (PeerInfo, PeersPriorityStatus) => PeerInfo = (i, p) => i.copy(peerPriorityStatus = p)
  private def updateComparisonResult: (PeerInfo, HistoryComparisonResult) => PeerInfo = (i, h) => i.copy(historyComparisonResult = h)
  private def updateUptime: (PeerInfo, LastUptime) => PeerInfo = (i, u) => i.copy(lastUptime = u)

}

object ConnectedPeersCollection {

  final case class LastUptime(time: Long) extends AnyVal

  final case class PeerInfo(historyComparisonResult: HistoryComparisonResult,
                            peerPriorityStatus: PeersPriorityStatus,
                            connectedPeer: ConnectedPeer,
                            connectionType: ConnectionType,
                            lastUptime: LastUptime)

  def apply(): ConnectedPeersCollection = ConnectedPeersCollection(Map.empty[InetSocketAddress, PeerInfo])
} 
Example 113
Source File: SnapshotDownloadControllerStorageAPI.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.view.fast.sync

import com.typesafe.scalalogging.StrictLogging
import encry.settings.EncryAppSettings
import encry.view.fast.sync.SnapshotHolder.SnapshotManifest.ChunkId
import org.encryfoundation.common.utils.Algos
import org.iq80.leveldb.DB

trait SnapshotDownloadControllerStorageAPI extends DBTryCatchFinallyProvider with StrictLogging {

  val storage: DB

  val settings: EncryAppSettings

  def nextGroupKey(n: Int): Array[Byte] = Algos.hash(s"next_group_key_$n")

  
  def getNextForRequest(groupNumber: Int): Either[Throwable, List[ChunkId]] =
    readWrite(
      (batch, readOptions, _) => {
        logger.debug(s"Going to get next group for request with number $groupNumber.")
        val res = storage.get(nextGroupKey(groupNumber), readOptions)
        val buffer: List[ChunkId] =
          if (res != null) {
            val value = res.grouped(32).toList.map(ChunkId @@ _)
            logger.debug(s"Gotten group is non empty. Elements number is ${value.size}.")
            logger.debug(s"First element of the group is: ${value.headOption.map(Algos.encode)}")
            batch.delete(nextGroupKey(groupNumber))
            value
          } else {
            logger.debug(s"Requested group is null")
            throw new Exception("Inconsistent snapshot download controller db state!")
          }
        buffer
      }
    )
} 
Example 114
Source File: DBTryCatchFinallyProvider.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.view.fast.sync

import com.typesafe.scalalogging.StrictLogging
import org.iq80.leveldb.{ DB, DBIterator, ReadOptions, WriteBatch }
import cats.syntax.either._

trait DBTryCatchFinallyProvider extends StrictLogging {

  val storage: DB

  def readWrite[Output](
    f: (WriteBatch, ReadOptions, DBIterator) => Output
  ): Either[Throwable, Output] = {
    val snapshot             = storage.getSnapshot
    val readOptions          = new ReadOptions().snapshot(snapshot)
    val batch: WriteBatch    = storage.createWriteBatch()
    val iterator: DBIterator = storage.iterator(readOptions)
    try {
      val output = f(batch, readOptions, iterator)
      storage.write(batch)
      output.asRight[Throwable]
    } catch {
      case error: Throwable =>
        logger.info(s"Error has occurred $error")
        error.asLeft[Output]
    } finally {
      iterator.close()
      readOptions.snapshot().close()
      batch.close()
    }
  }
} 
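A minimal sketch of mixing the trait into a concrete store, assuming the pure-Java iq80 LevelDB factory is available on the classpath; the database path is hypothetical.

import java.io.File
import encry.view.fast.sync.DBTryCatchFinallyProvider
import org.iq80.leveldb.{DB, Options}
import org.iq80.leveldb.impl.Iq80DBFactory

object ReadWriteSketch extends App {
  val store = new DBTryCatchFinallyProvider {
    override val storage: DB = Iq80DBFactory.factory.open(new File("/tmp/sketch-db"), new Options())
  }

  // The batch is applied only after f returns; reads go through the snapshot taken beforehand.
  val result = store.readWrite { (batch, readOptions, _) =>
    batch.put("key".getBytes, "value".getBytes)
    Option(store.storage.get("key".getBytes, readOptions)).map(new String(_))
  }
  println(result) // Right(None) on a fresh database: the snapshot predates the write

  store.storage.close()
}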
Example 115
Source File: SnapshotProcessorStorageAPI.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.view.fast.sync

import SnapshotChunkProto.SnapshotChunkMessage
import SnapshotManifestProto.SnapshotManifestProtoMessage
import com.typesafe.scalalogging.StrictLogging
import encry.storage.VersionalStorage
import encry.storage.VersionalStorage.{ StorageKey, StorageValue }
import encry.view.fast.sync.SnapshotHolder.{ SnapshotManifest, SnapshotManifestSerializer }
import org.encryfoundation.common.utils.Algos
import scala.util.Try

trait SnapshotProcessorStorageAPI extends StrictLogging {

  val storage: VersionalStorage

  def getManifestId(id: StorageKey): Option[Array[Byte]] = storage.get(id)

  def actualManifestId: Option[Array[Byte]] = getManifestId(ActualManifestKey)

  def manifestById(id: StorageKey): Option[SnapshotManifest] =
    storage
      .get(StorageKey @@ id)
      .flatMap(bytes => SnapshotManifestSerializer.fromProto(SnapshotManifestProtoMessage.parseFrom(bytes)).toOption)

  def actualManifest: Option[SnapshotManifest] =
    actualManifestId
      .flatMap(id => manifestById(StorageKey !@@ id))

  def potentialManifestsIds: Seq[Array[Byte]] =
    storage
      .get(PotentialManifestsIdsKey)
      .map(_.grouped(32).toSeq)
      .getOrElse(Seq.empty)

  def manifestBytesById(id: StorageKey): Option[StorageValue] = storage.get(id)

  def getChunkById(chunkId: Array[Byte]): Option[SnapshotChunkMessage] =
    storage.get(StorageKey @@ chunkId).flatMap(e => Try(SnapshotChunkMessage.parseFrom(e)).toOption)

  val ActualManifestKey: StorageKey        = StorageKey @@ Algos.hash("actual_manifest_key")
  val PotentialManifestsIdsKey: StorageKey = StorageKey @@ Algos.hash("potential_manifests_ids_key")
} 
Example 116
Source File: ShadowNode.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.view.state.avlTree

import NodeMsg.NodeProtoMsg.NodeTypes.ShadowNodeProto
import cats.Monoid
import com.google.protobuf.ByteString
import com.typesafe.scalalogging.StrictLogging
import encry.storage.VersionalStorage
import encry.storage.VersionalStorage.StorageKey
import encry.view.state.avlTree.utils.implicits.{Hashable, Serializer}
import io.iohk.iodb.ByteArrayWrapper
import org.encryfoundation.common.utils.Algos

import scala.util.Try

abstract class ShadowNode[K: Serializer: Monoid, V: Serializer: Monoid] extends Node[K, V] {
  def restoreFullNode(storage: VersionalStorage): Node[K, V]
  final override def selfInspection = this
}

object ShadowNode {

  def nodeToShadow[K: Serializer : Hashable : Monoid, V: Serializer : Monoid](node: Node[K, V]): ShadowNode[K, V] = node match {
    case internal: InternalNode[K, V] =>
      NonEmptyShadowNode(internal.hash, internal.height, internal.balance, internal.key)
    case leaf: LeafNode[K, V] =>
      NonEmptyShadowNode(leaf.hash, leaf.height, leaf.balance, leaf.key)
    case _: EmptyNode[K, V] =>
      new EmptyShadowNode[K, V]
    case anotherShadow: ShadowNode[K, V] => anotherShadow
  }

  def childsToShadowNode[K: Serializer : Hashable : Monoid, V: Serializer : Monoid](node: Node[K, V]): Node[K, V] = node match {
    case internal: InternalNode[K, V] =>
      internal.copy(
        leftChild = nodeToShadow(internal.leftChild),
        rightChild = nodeToShadow(internal.rightChild),
      )
    case _ => node
  }
} 
Example 117
Source File: NonEmptyShadowNode.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.view.state.avlTree

import NodeMsg.NodeProtoMsg.NodeTypes.{NonEmptyShadowNodeProto, ShadowNodeProto}
import cats.Monoid
import com.google.protobuf.ByteString
import com.typesafe.scalalogging.StrictLogging
import encry.storage.VersionalStorage
import encry.storage.VersionalStorage.StorageKey
import encry.view.state.avlTree.utils.implicits.{Hashable, Serializer}
import org.encryfoundation.common.utils.Algos

import scala.util.Try

case class NonEmptyShadowNode[K: Serializer: Hashable, V: Serializer](nodeHash: Array[Byte],
                                                                      height: Int,
                                                                      balance: Int,
                                                                      key: K)
                                                                     (implicit kM: Monoid[K], vM: Monoid[V]) extends ShadowNode[K, V] with StrictLogging {

  override val value: V = vM.empty

  override val hash = nodeHash

  def restoreFullNode(storage: VersionalStorage): Node[K, V] = if (nodeHash.nonEmpty) {
    NodeSerilalizer.fromBytes[K, V](
      {
        val res = storage.get(StorageKey @@ AvlTree.nodeKey(hash))
        if (res.isEmpty) logger.info(s"Empty node at key: ${Algos.encode(hash)}")
        res.get
      }
    )
  } else EmptyNode[K, V]()

  def tryRestore(storage: VersionalStorage): Option[Node[K, V]] =
    Try(restoreFullNode(storage)).toOption

  override def toString: String = s"ShadowNode(Hash:${Algos.encode(hash)}, height: ${height}, balance: ${balance})"
}

object NonEmptyShadowNode {

  def toProto[K: Serializer, V](node: NonEmptyShadowNode[K, V]): NonEmptyShadowNodeProto = NonEmptyShadowNodeProto()
    .withHeight(node.height)
    .withHash(ByteString.copyFrom(node.hash))
    .withBalance(node.balance)
    .withKey(ByteString.copyFrom(implicitly[Serializer[K]].toBytes(node.key)))

  def fromProto[K: Hashable : Monoid : Serializer, V: Monoid : Serializer](protoNode: NonEmptyShadowNodeProto): Try[NonEmptyShadowNode[K, V]] = Try {
    NonEmptyShadowNode(
      nodeHash = protoNode.hash.toByteArray,
      height = protoNode.height,
      balance = protoNode.balance,
      key = implicitly[Serializer[K]].fromBytes(protoNode.key.toByteArray)
    )
  }
} 
Example 118
Source File: HistoryDBApi.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.view.history

import com.google.common.primitives.Ints
import com.typesafe.scalalogging.StrictLogging
import encry.settings.EncryAppSettings
import encry.storage.VersionalStorage.StorageKey
import encry.view.history.storage.HistoryStorage
import org.encryfoundation.common.modifiers.history.{Block, Header, Payload}
import org.encryfoundation.common.utils.Algos
import org.encryfoundation.common.utils.TaggedTypes.{Height, ModifierId, ModifierTypeId}
import scorex.crypto.hash.Digest32

import scala.reflect.ClassTag

trait HistoryDBApi extends StrictLogging {

  val settings: EncryAppSettings

  val historyStorage: HistoryStorage

  lazy val BestHeaderKey: StorageKey =
    StorageKey @@ Array.fill(settings.constants.DigestLength)(Header.modifierTypeId.untag(ModifierTypeId))
  lazy val BestBlockKey: StorageKey =
    StorageKey @@ Array.fill(settings.constants.DigestLength)(-1: Byte)

  private def getModifierById[T: ClassTag](id: ModifierId): Option[T] = historyStorage
    .modifierById(id)
    .collect { case m: T => m }

  def getHeightByHeaderIdDB(id: ModifierId): Option[Int] = historyStorage
    .get(headerHeightKey(id))
    .map(Ints.fromByteArray)

  def getHeaderByIdDB(id: ModifierId): Option[Header] = getModifierById[Header](id)
  def getPayloadByIdDB(pId: ModifierId): Option[Payload] = getModifierById[Payload](pId)
  def getBlockByHeaderDB(header: Header): Option[Block] = getModifierById[Payload](header.payloadId)
    .map(payload => Block(header, payload))
  def getBlockByHeaderIdDB(id: ModifierId): Option[Block] = getHeaderByIdDB(id)
    .flatMap(h => getModifierById[Payload](h.payloadId).map(p => Block(h, p)))

  def getBestHeaderId: Option[ModifierId] = historyStorage.get(BestHeaderKey).map(ModifierId @@ _)
  def getBestHeaderDB: Option[Header] = getBestHeaderId.flatMap(getHeaderByIdDB)
  def getBestHeaderHeightDB: Int = getBestHeaderId
    .flatMap(getHeightByHeaderIdDB)
    .getOrElse(settings.constants.PreGenesisHeight)

  def getBestBlockId: Option[ModifierId] = historyStorage.get(BestBlockKey).map(ModifierId @@ _)
  def getBestBlockDB: Option[Block] = getBestBlockId.flatMap(getBlockByHeaderIdDB)
  def getBestBlockHeightDB: Int = getBestBlockId
    .flatMap(getHeightByHeaderIdDB)
    .getOrElse(settings.constants.PreGenesisHeight)

  def modifierBytesByIdDB(id: ModifierId): Option[Array[Byte]] = historyStorage.modifiersBytesById(id)

  def isModifierDefined(id: ModifierId): Boolean = historyStorage.containsMod(id)

  //todo probably rewrite with indexes collection
  def lastBestBlockHeightRelevantToBestChain(probablyAt: Int): Option[Int] = (for {
    headerId <- getBestHeaderIdAtHeightDB(probablyAt)
    header   <- getHeaderByIdDB(headerId) if isModifierDefined(header.payloadId)
  } yield header.height).orElse(lastBestBlockHeightRelevantToBestChain(probablyAt - 1))

  def headerIdsAtHeightDB(height: Int): Option[Seq[ModifierId]] = historyStorage
    .get(heightIdsKey(height))
    .map(_.grouped(32).map(ModifierId @@ _).toSeq)

  def getBestHeaderIdAtHeightDB(h: Int): Option[ModifierId] = headerIdsAtHeightDB(h).flatMap(_.headOption)

  def getBestHeaderAtHeightDB(h: Int): Option[Header] = getBestHeaderIdAtHeightDB(h).flatMap(getHeaderByIdDB)

  def isInBestChain(h: Header): Boolean = getBestHeaderIdAtHeightDB(h.height)
    .exists(_.sameElements(h.id))

  def isInBestChain(id: ModifierId): Boolean = heightOf(id)
    .flatMap(getBestHeaderIdAtHeightDB)
    .exists(_.sameElements(id))

  def getBestHeadersChainScore: BigInt = getBestHeaderId.flatMap(scoreOf).getOrElse(BigInt(0)) //todo ?.getOrElse(BigInt(0))?

  def scoreOf(id: ModifierId): Option[BigInt] = historyStorage
    .get(headerScoreKey(id))
    .map(d => BigInt(d))

  def heightOf(id: ModifierId): Option[Height] = historyStorage
    .get(headerHeightKey(id))
    .map(d => Height @@ Ints.fromByteArray(d))

  def heightIdsKey(height: Int): StorageKey =
    StorageKey @@ Algos.hash(Ints.toByteArray(height)).untag(Digest32)
  def headerScoreKey(id: ModifierId): StorageKey =
    StorageKey @@ Algos.hash("score".getBytes(Algos.charset) ++ id).untag(Digest32)
  def headerHeightKey(id: ModifierId): StorageKey =
    StorageKey @@ Algos.hash("height".getBytes(Algos.charset) ++ id).untag(Digest32)
  def validityKey(id: Array[Byte]): StorageKey =
    StorageKey @@ Algos.hash("validity".getBytes(Algos.charset) ++ id).untag(Digest32)
} 
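A minimal sketch of mixing this trait into a concrete reader (not part of the original project): supplying the two abstract vals is all the trait requires, and the inherited StrictLogging logger is then available directly. HistoryReader and logBestHeight are hypothetical names; the EncryAppSettings and HistoryStorage instances are assumed to already exist.

import encry.settings.EncryAppSettings
import encry.view.history.storage.HistoryStorage

// Hypothetical implementor: the constructor vals satisfy the trait's abstract members.
class HistoryReader(val settings: EncryAppSettings,
                    val historyStorage: HistoryStorage) extends HistoryDBApi {

  // logger comes from StrictLogging, which HistoryDBApi already mixes in.
  def logBestHeight(): Unit =
    logger.info(s"Best header height: $getBestHeaderHeightDB")
}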
Example 119
Source File: HistoryStorage.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.view.history.storage

import cats.syntax.option._
import com.typesafe.scalalogging.StrictLogging
import encry.storage.{EncryStorage, VersionalStorage}
import encry.storage.VersionalStorage.{StorageKey, StorageValue, StorageVersion}
import encry.storage.iodb.versionalIODB.IODBHistoryWrapper
import encry.storage.levelDb.versionalLevelDB.VLDBWrapper
import io.iohk.iodb.ByteArrayWrapper
import org.encryfoundation.common.modifiers.PersistentModifier
import org.encryfoundation.common.modifiers.history.HistoryModifiersProtoSerializer
import org.encryfoundation.common.utils.TaggedTypes.ModifierId
import scorex.utils.{Random => ScorexRandom}
import scala.util.{Failure, Random, Success}

case class HistoryStorage(override val store: VersionalStorage) extends EncryStorage with StrictLogging {

  def modifierById(id: ModifierId): Option[PersistentModifier] = (store match {
    case iodb: IODBHistoryWrapper => iodb.objectStore.get(ByteArrayWrapper(id)).map(_.data)
    case _: VLDBWrapper           => store.get(StorageKey @@ id.untag(ModifierId))
  })
    .flatMap(res => HistoryModifiersProtoSerializer.fromProto(res) match {
      case Success(b) => b.some
      case Failure(e) => logger.warn(s"Failed to parse block from db: $e"); none
    })

  def containsMod(id: ModifierId): Boolean = store match {
    case iodb: IODBHistoryWrapper => iodb.objectStore.get(ByteArrayWrapper(id)).isDefined
    case _: VLDBWrapper           => store.contains(StorageKey @@ id.untag(ModifierId))
  }

  def modifiersBytesById(id: ModifierId): Option[Array[Byte]] = store match {
    case iodb: IODBHistoryWrapper => iodb.objectStore.get(ByteArrayWrapper(id)).map(_.data.tail)
    case _: VLDBWrapper           => store.get(StorageKey @@ id.untag(ModifierId)).map(_.tail)
  }

  def insertObjects(objectsToInsert: Seq[PersistentModifier]): Unit = store match {
    case iodb: IODBHistoryWrapper =>
      iodb.objectStore.update(
        Random.nextLong(),
        Seq.empty,
        objectsToInsert.map(obj => ByteArrayWrapper(obj.id) ->
          ByteArrayWrapper(HistoryModifiersProtoSerializer.toProto(obj)))
      )
    case _: VLDBWrapper =>
      insert(
        StorageVersion @@ objectsToInsert.head.id.untag(ModifierId),
        objectsToInsert.map(obj =>
          StorageKey @@ obj.id.untag(ModifierId) -> StorageValue @@ HistoryModifiersProtoSerializer.toProto(obj)
        ).toList,
      )
  }

  def bulkInsert(version: Array[Byte],
                 indexesToInsert: Seq[(Array[Byte], Array[Byte])],
                 objectsToInsert: Seq[PersistentModifier]): Unit = store match {
    case _: IODBHistoryWrapper =>
      insertObjects(objectsToInsert)
      insert(
        StorageVersion @@ version,
        indexesToInsert.map { case (key, value) => StorageKey @@ key -> StorageValue @@ value }.toList
      )
    case _: VLDBWrapper =>
      logger.info(s"Inserting2: $objectsToInsert")
      insert(
        StorageVersion @@ version,
        (indexesToInsert.map { case (key, value) =>
          StorageKey @@ key -> StorageValue @@ value
        } ++ objectsToInsert.map { obj =>
          StorageKey @@ obj.id.untag(ModifierId) -> StorageValue @@ HistoryModifiersProtoSerializer.toProto(obj)
        }).toList
      )
  }

  def removeObjects(ids: Seq[ModifierId]): Unit = store match {
    case iodb: IODBHistoryWrapper =>
      iodb.objectStore.update(Random.nextLong(), ids.map(ByteArrayWrapper.apply), Seq.empty)
    case _: VLDBWrapper =>
      store.insert(
        StorageVersion @@ ScorexRandom.randomBytes(),
        toInsert = List.empty,
        ids.map(elem => StorageKey @@ elem.untag(ModifierId)).toList
      )
  }
} 
Example 120
Source File: AccStorage.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.api.http

import java.io.File
import cats.syntax.either._
import com.typesafe.scalalogging.StrictLogging
import encry.settings.EncryAppSettings
import encry.storage.VersionalStorage.StorageKey
import encry.storage.levelDb.versionalLevelDB.LevelDbFactory
import org.encryfoundation.common.utils.Algos
import org.iq80.leveldb.{DB, Options}
import scorex.utils.Random
import supertagged.TaggedType

trait AccStorage extends StrictLogging with AutoCloseable {

  val storage: DB

  val verifyPassword: String => Boolean = pass => {
    val salt = storage.get(AccStorage.SaltKey)
    val passHash = storage.get(AccStorage.PasswordHashKey)
    Algos.hash(pass.getBytes() ++ salt) sameElements passHash
  }

  def setPassword(pass: String): Either[Throwable, Unit] = {
    val batch = storage.createWriteBatch()
    val salt = Random.randomBytes()
    try {
      batch.put(AccStorage.PasswordHashKey, Algos.hash(pass.getBytes() ++ salt))
      batch.put(AccStorage.SaltKey, salt)
      storage.write(batch).asRight[Throwable]
    } catch {
      case err: Throwable => err.asLeft[Unit]
    }
    finally {
      batch.close()
    }
  }

  override def close(): Unit = storage.close()

}

object AccStorage extends StrictLogging {

  object PasswordHash extends TaggedType[Array[Byte]]
  object PasswordSalt extends TaggedType[Array[Byte]]

  type PasswordHash = PasswordHash.Type
  type PasswordSalt = PasswordSalt.Type

  val PasswordHashKey: StorageKey = StorageKey @@ Algos.hash("Password_Key")
  val SaltKey: StorageKey = StorageKey @@ Algos.hash("Salt_Key")

  def getDirStorage(settings: EncryAppSettings): File = new File(s"${settings.directory}/userKeys")

  def init(settings: EncryAppSettings): AccStorage = new AccStorage {
    override val storage: DB = LevelDbFactory.factory.open(getDirStorage(settings), new Options)
  }

} 
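A usage sketch of the set/verify round trip exposed above; `settings` is assumed to be an already-loaded EncryAppSettings and the password literal is illustrative.

// Sketch only: `settings` is assumed to be in scope.
val accStorage: AccStorage = AccStorage.init(settings)
accStorage.setPassword("my-secret") match {
  case Right(_)  => println(accStorage.verifyPassword("my-secret")) // prints true once hash and salt are stored
  case Left(err) => println(s"Failed to persist password: $err")
}
accStorage.close()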
Example 121
Source File: NetworkTime.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.utils

import java.net.InetAddress

import com.typesafe.scalalogging.StrictLogging
import encry.utils.NetworkTime.Time
import org.apache.commons.net.ntp.{NTPUDPClient, TimeInfo}

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent.Future
import scala.util.Left
import scala.util.control.NonFatal

object NetworkTime {
  def localWithOffset(offset: Long): Long = System.currentTimeMillis() + offset

  type Offset = Long
  type Time = Long
}

protected case class NetworkTime(offset: NetworkTime.Offset, lastUpdate: NetworkTime.Time)

case class NetworkTimeProviderSettings(server: String, updateEvery: FiniteDuration, timeout: FiniteDuration)

class NetworkTimeProvider(ntpSettings: NetworkTimeProviderSettings) extends StrictLogging {

  private var state: State = Right(NetworkTime(0L, 0L))
  private var delta: Time = 0L

  private type State = Either[(NetworkTime, Future[NetworkTime]), NetworkTime]

  private def updateOffSet(): Option[NetworkTime.Offset] = {
    val client: NTPUDPClient = new NTPUDPClient()
    client.setDefaultTimeout(ntpSettings.timeout.toMillis.toInt)
    try {
      client.open()
      val info: TimeInfo = client.getTime(InetAddress.getByName(ntpSettings.server))
      info.computeDetails()
      Option(info.getOffset)
    } catch {
      case NonFatal(_) => None
      case NonFatal(_) => None
    } finally {
      client.close()
    }
  }

  private def timeAndState(currentState: State): Future[(NetworkTime.Time, State)] =
    currentState match {
      case Right(nt) =>
        val time: Long = NetworkTime.localWithOffset(nt.offset)
        val state: Either[(NetworkTime, Future[NetworkTime]), NetworkTime] =
          if (time > nt.lastUpdate + ntpSettings.updateEvery.toMillis) {
            Left(nt -> Future(updateOffSet()).map { mbOffset =>
              logger.info("New offset adjusted: " + mbOffset)
              val offset = mbOffset.getOrElse(nt.offset)
              NetworkTime(offset, NetworkTime.localWithOffset(offset))
            })
          } else Right(nt)
        Future.successful((time, state))
      case Left((nt, networkTimeFuture)) =>
        networkTimeFuture
          .map(networkTime => NetworkTime.localWithOffset(networkTime.offset) -> Right(networkTime))
          .recover {
            case NonFatal(th) =>
              logger.warn(s"Failed to evaluate networkTimeFuture $th")
              NetworkTime.localWithOffset(nt.offset) -> Left(nt -> networkTimeFuture)
          }
    }

  def estimatedTime: Time = state match {
    case Right(nt) if NetworkTime.localWithOffset(nt.offset) <= nt.lastUpdate + ntpSettings.updateEvery.toMillis =>
      NetworkTime.localWithOffset(nt.offset)
    case _ => System.currentTimeMillis() + delta
  }

  def time(): Future[NetworkTime.Time] =
    timeAndState(state)
      .map { case (timeFutureResult, stateFutureResult) =>
        state = stateFutureResult
        delta = timeFutureResult - System.currentTimeMillis()
        timeFutureResult
      }

} 
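A short usage sketch; the NTP server and intervals below are illustrative values, not taken from the project configuration.

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// Illustrative settings: public NTP pool, refresh every 10 minutes, 30 second timeout.
val ntpSettings  = NetworkTimeProviderSettings("pool.ntp.org", 10.minutes, 30.seconds)
val timeProvider = new NetworkTimeProvider(ntpSettings)

timeProvider.time().foreach(t => println(s"Estimated network time: $t"))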
Example 122
Source File: Zombie.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.stats

import akka.actor.{Actor, DeadLetter, UnhandledMessage}
import com.typesafe.scalalogging.StrictLogging

class Zombie extends Actor with StrictLogging {

  override def preStart(): Unit = {
    context.system.eventStream.subscribe(self, classOf[DeadLetter])
    context.system.eventStream.subscribe(self, classOf[UnhandledMessage])
  }

  override def receive: Receive = {
    case deadMessage: DeadLetter => logger.info(s"Dead letter: ${deadMessage.toString}. " +
      s"From: ${deadMessage.sender}. To ${deadMessage.recipient}")
    case unhandled: UnhandledMessage => logger.info(s"Unhandled letter: ${unhandled.toString}. " +
      s"From: ${unhandled.sender}. To ${unhandled.recipient}")
  }
} 
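A minimal sketch of wiring the actor into a system (the system and actor names are arbitrary); once started it subscribes itself to DeadLetter and UnhandledMessage events and logs them.

import akka.actor.{ActorSystem, Props}

val system = ActorSystem("example")         // hypothetical system name
system.actorOf(Props(new Zombie), "zombie") // subscriptions happen in preStart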
Example 123
Source File: WalletVersionalLevelDB.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.storage.levelDb.versionalLevelDB

import cats.instances.all._
import cats.syntax.semigroup._
import com.google.common.primitives.Longs
import com.typesafe.scalalogging.StrictLogging
import encry.settings.LevelDBSettings
import encry.storage.levelDb.versionalLevelDB.VersionalLevelDBCompanion._
import encry.utils.{BalanceCalculator, ByteStr}
import org.encryfoundation.common.modifiers.state.StateModifierSerializer
import org.encryfoundation.common.modifiers.state.box.Box.Amount
import org.encryfoundation.common.modifiers.state.box.EncryBaseBox
import org.encryfoundation.common.modifiers.state.box.TokenIssuingBox.TokenId
import org.encryfoundation.common.utils.Algos
import org.encryfoundation.common.utils.TaggedTypes.{ADKey, ModifierId}
import org.iq80.leveldb.DB
import scorex.crypto.hash.Digest32
import scala.util.Success

case class WalletVersionalLevelDB(db: DB, settings: LevelDBSettings) extends StrictLogging with AutoCloseable {

  import WalletVersionalLevelDBCompanion._

  val levelDb: VersionalLevelDB = VersionalLevelDB(db, settings)

  //todo: optimize this
  def getAllBoxes(maxQty: Int = -1): Seq[EncryBaseBox] = levelDb.getAll(maxQty)
    .filterNot(_._1 sameElements BALANCE_KEY)
    .map { case (key, bytes) => StateModifierSerializer.parseBytes(bytes, key.head) }
    .collect { case Success(box) => box }

  def getBoxById(id: ADKey): Option[EncryBaseBox] = levelDb.get(VersionalLevelDbKey @@ id.untag(ADKey))
    .flatMap(wrappedBx => StateModifierSerializer.parseBytes(wrappedBx, id.head).toOption)

  def getTokenBalanceById(id: TokenId): Option[Amount] = getBalances
    .find(_._1._2 == Algos.encode(id))
    .map(_._2)

  def containsBox(id: ADKey): Boolean = getBoxById(id).isDefined

  def rollback(modId: ModifierId): Unit = levelDb.rollbackTo(LevelDBVersion @@ modId.untag(ModifierId))

  def updateWallet(modifierId: ModifierId, newBxs: Seq[EncryBaseBox], spentBxs: Seq[EncryBaseBox],
                   intrinsicTokenId: ADKey): Unit = {
    val bxsToInsert: Seq[EncryBaseBox] = newBxs.filter(bx => !spentBxs.contains(bx))
    val newBalances: Map[(String, String), Amount] = {
      val toRemoveFromBalance = BalanceCalculator.balanceSheet(spentBxs, intrinsicTokenId)
        .map { case ((hash, key), value) => (hash, ByteStr(key)) -> value * -1 }
      val toAddToBalance = BalanceCalculator.balanceSheet(newBxs, intrinsicTokenId)
        .map { case ((hash, key), value) => (hash, ByteStr(key)) -> value }
      val prevBalance = getBalances.map { case ((hash, id), value) => (hash, ByteStr(Algos.decode(id).get)) -> value }
      (toAddToBalance |+| toRemoveFromBalance |+| prevBalance)
        .map { case ((hash, tokenId), value) => (hash, tokenId.toString) -> value }
    }
    val newBalanceKeyValue = BALANCE_KEY -> VersionalLevelDbValue @@
      newBalances.foldLeft(Array.emptyByteArray) { case (acc, ((hash, tokenId), balance)) =>
        acc ++ Algos.decode(hash).get ++ Algos.decode(tokenId).get ++ Longs.toByteArray(balance)
      }
    levelDb.insert(LevelDbDiff(LevelDBVersion @@ modifierId.untag(ModifierId),
      newBalanceKeyValue :: bxsToInsert.map(bx => (VersionalLevelDbKey @@ bx.id.untag(ADKey),
        VersionalLevelDbValue @@ bx.bytes)).toList,
      spentBxs.map(elem => VersionalLevelDbKey @@ elem.id.untag(ADKey)))
    )
  }

  def getBalances: Map[(String, String), Amount] =
    levelDb.get(BALANCE_KEY)
      .map(_.sliding(72, 72)
        .map(ch => (Algos.encode(ch.take(32)), Algos.encode(ch.slice(32, 64))) -> Longs.fromByteArray(ch.takeRight(8)))
        .toMap).getOrElse(Map.empty)

  override def close(): Unit = levelDb.close()
}

object WalletVersionalLevelDBCompanion extends StrictLogging {

  val BALANCE_KEY: VersionalLevelDbKey =
    VersionalLevelDbKey @@ Algos.hash("BALANCE_KEY").untag(Digest32)

  val INIT_MAP: Map[VersionalLevelDbKey, VersionalLevelDbValue] = Map(
    BALANCE_KEY -> VersionalLevelDbValue @@ Array.emptyByteArray
  )

  def apply(levelDb: DB, settings: LevelDBSettings): WalletVersionalLevelDB = {
    val db = WalletVersionalLevelDB(levelDb, settings)
    db.levelDb.recoverOrInit(INIT_MAP)
    db
  }
} 
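A usage sketch, assuming an already-opened org.iq80.leveldb.DB (`db`) and a LevelDBSettings value (`levelDBSettings`); the companion apply above seeds BALANCE_KEY before the wallet is used.

// Sketch only: `db` and `levelDBSettings` are assumed to exist.
val wallet = WalletVersionalLevelDBCompanion(db, levelDBSettings)
wallet.getBalances.foreach { case ((contractHash, tokenId), amount) =>
  println(s"$contractHash / $tokenId -> $amount")
}
wallet.close()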
Example 124
Source File: LevelDbFactory.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.storage.levelDb.versionalLevelDB

import com.typesafe.scalalogging.StrictLogging
import org.iq80.leveldb.DBFactory

import scala.util.Try

object LevelDbFactory extends StrictLogging {
  private val nativeFactory = "org.fusesource.leveldbjni.JniDBFactory"
  private val javaFactory   = "org.iq80.leveldb.impl.Iq80DBFactory"

  lazy val factory: DBFactory = {
    val pairs = for {
      loader      <- List(ClassLoader.getSystemClassLoader, this.getClass.getClassLoader).view
      factoryName <- List(nativeFactory, javaFactory)
      factory     <- Try(loader.loadClass(factoryName).getConstructor().newInstance().asInstanceOf[DBFactory]).toOption
    } yield (factoryName, factory)

    val (fName, f) = pairs.headOption.getOrElse(
      throw new RuntimeException(s"Could not load any of the factory classes: $nativeFactory, $javaFactory"))
    if (fName == javaFactory) logger.warn("Using the pure java LevelDB implementation which is still experimental")
    else logger.trace(s"Loaded $fName with $f")
    f
  }
} 
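A brief sketch of opening a store through the factory; the path is illustrative. If only the pure-Java implementation can be loaded, the warning above is logged.

import java.io.File
import org.iq80.leveldb.{DB, Options}

val db: DB = LevelDbFactory.factory.open(new File("/tmp/example-db"), new Options)
// ... read/write ...
db.close()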
Example 125
Source File: IODBWrapper.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.storage.iodb.versionalIODB

import com.typesafe.scalalogging.StrictLogging
import encry.storage.VersionalStorage
import encry.storage.VersionalStorage.{StorageKey, StorageValue, StorageVersion}
import io.iohk.iodb.Store.{K, V}
import io.iohk.iodb.{ByteArrayWrapper, Store}
import org.encryfoundation.common.utils.Algos

import scala.collection.mutable


case class IODBWrapper(store: Store) extends VersionalStorage with StrictLogging {

  override def get(key: StorageKey): Option[StorageValue] =
    store.get(ByteArrayWrapper(key)).map(StorageValue @@ _.data)

  override def contains(key: StorageKey): Boolean = get(key).isDefined

  override def currentVersion: StorageVersion =
    store.lastVersionID.map(StorageVersion @@ _.data).getOrElse(IODBWrapper.initVer)

  override def versions: List[StorageVersion] =
    store.rollbackVersions().map(StorageVersion @@ _.data).toList

  override def rollbackTo(to: StorageVersion): Unit =
    store.rollback(ByteArrayWrapper(to))

  override def insert(version: StorageVersion,
                      toInsert: List[(StorageKey, StorageValue)],
                      toDelete: List[StorageKey] = List.empty): Unit = {
    logger.info(s"Update to version: ${Algos.encode(version)}")
    store.update(
      ByteArrayWrapper(version),
      toDelete.map(ByteArrayWrapper.apply),
      toInsert.map{case (keyToAdd, valToAdd) => ByteArrayWrapper(keyToAdd) -> ByteArrayWrapper(valToAdd)}
    )
  }

  //always return all elements
  override def getAll(maxQty: Int = -1): Iterator[(StorageKey, StorageValue)] =
    store.getAll().map{case (key, value) => StorageKey @@ key.data -> StorageValue @@ value.data}

  override def getAllKeys(maxQty: Int = -1): Iterator[StorageKey] =
    store.getAll().map{case (key, _) => StorageKey @@ key.data}

  override def close(): Unit = store.close()
}

object IODBWrapper {

  val initVer: StorageVersion = StorageVersion @@ Array.fill(33)(0: Byte)
} 
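A sketch of the tagged-type round trip through the wrapper, assuming an already-opened io.iohk.iodb.Store (for example an LSMStore); the key, version and payload bytes are arbitrary.

import encry.storage.VersionalStorage.{StorageKey, StorageValue, StorageVersion}
import io.iohk.iodb.Store

// Sketch only: `store` is assumed to be an already-opened iodb Store.
def roundTrip(store: Store): Option[StorageValue] = {
  val wrapped = IODBWrapper(store)
  val key     = StorageKey @@ Array.fill(32)(1: Byte)
  // insert logs "Update to version: <hex>" via the inherited logger
  wrapped.insert(
    StorageVersion @@ Array.fill(33)(2: Byte),
    List(key -> (StorageValue @@ "payload".getBytes))
  )
  wrapped.get(key)
}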
Example 126
Source File: EncryPropositionFunctions.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.modifiers.state

import com.typesafe.scalalogging.StrictLogging
import encry.utils.RegularContractEvaluator
import io.iohk.iodb.ByteArrayWrapper
import org.encryfoundation.common.modifiers.mempool.transaction.{Proof, PubKeyLockedContract, RegularContract}
import org.encryfoundation.common.modifiers.state.box.EncryProposition
import org.encryfoundation.common.utils.Algos
import org.encryfoundation.prismlang.codec.PCodec
import org.encryfoundation.prismlang.compiler.CompiledContract
import org.encryfoundation.prismlang.core.Ast.Expr
import org.encryfoundation.prismlang.core.wrapped.PValue
import org.encryfoundation.prismlang.evaluator.Evaluator
import scorex.crypto.encode.Base16
import scorex.crypto.hash.Blake2b256
import scorex.crypto.signatures.PublicKey
import scorex.utils.Random

object EncryPropositionFunctions extends StrictLogging {

  val contract = PubKeyLockedContract(PublicKey @@ Random.randomBytes())

  def pubKeyContractBytes(key: Array[Byte]): Array[Byte] = {
    val pubKeyBytes = PCodec.exprCodec.encode(Expr.Base16Str(Base16.encode(key))).require.toByteArray
    contract.contract.bytes.dropRight(pubKeyBytes.length + 1) ++ pubKeyBytes :+ (3: Byte)
  }

  def canUnlock(proposition: EncryProposition, ctx: Context, contract: Either[CompiledContract, RegularContract], proofs: Seq[Proof]): Boolean =
    contract.fold (
      cc => if (sameHash(proposition.contractHash, cc.hash)) {
        val env: List[(Option[String], PValue)] =
          if (cc.args.isEmpty) List.empty
          else List((None, ctx.transaction.asVal), (None, ctx.state.asVal), (None, ctx.box.asVal)) ++
            proofs.map(proof => (proof.tagOpt, proof.value))
        val args: List[(String, PValue)] = cc.args.map { case (name, tpe) =>
          env.find(_._1.contains(name))
            .orElse(env.find(e => e._2.tpe == tpe || tpe.isSubtypeOf(e._2.tpe)))
            .map(elt => name -> elt._2)
            .getOrElse(throw new Exception("Not enough arguments for contract")) }
        Evaluator.initializedWith(args).eval[Boolean](cc.script)
      } else false,
      rc => {
        val contractHash = rc match {
          case PubKeyLockedContract(pubKey) => Blake2b256.hash(pubKeyContractBytes(pubKey))
          case anotherContract => anotherContract.contract.hash
        }
        if (sameHash(proposition.contractHash, contractHash)) RegularContractEvaluator.eval(rc, ctx, proofs)
        else false
      }
  )

  def sameHash(h1: Array[Byte], h2: Array[Byte]): Boolean = ByteArrayWrapper(h1) == ByteArrayWrapper(h2)
} 
Example 127
Source File: MemoryPoolTests.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.view.mempool

import akka.actor.ActorSystem
import akka.testkit.{ TestActorRef, TestProbe }
import com.typesafe.scalalogging.StrictLogging
import encry.modifiers.InstanceFactory
import encry.settings.{ EncryAppSettings, TestNetSettings }
import encry.utils.NetworkTimeProvider
import encry.view.mempool.MemoryPool.{ NewTransaction, TransactionsForMiner }
import org.scalatest.{ BeforeAndAfterAll, Matchers, OneInstancePerTest, WordSpecLike }

import scala.concurrent.duration._

class MemoryPoolTests
    extends WordSpecLike
    with Matchers
    with InstanceFactory
    with BeforeAndAfterAll
    with OneInstancePerTest
    with TestNetSettings
    with StrictLogging {

  implicit val system: ActorSystem = ActorSystem()

  override def afterAll(): Unit = system.terminate()

  val timeProvider: NetworkTimeProvider = new NetworkTimeProvider(testNetSettings.ntp)

  "MemoryPool" should {
    "add new unique transactions" in {
      val mempool                = MemoryPoolStorage.empty(testNetSettings, timeProvider)
      val transactions           = genValidPaymentTxs(10)
      val (newMempool, validTxs) = mempool.validateTransactions(transactions)
      newMempool.size shouldBe 10
      validTxs.map(_.encodedId).forall(transactions.map(_.encodedId).contains) shouldBe true
    }
    "reject not unique transactions" in {
      val mempool                          = MemoryPoolStorage.empty(testNetSettings, timeProvider)
      val transactions                     = genValidPaymentTxs(10)
      val (newMempool, validTxs)           = mempool.validateTransactions(transactions)
      val (newMempoolAgain, validTxsAgain) = newMempool.validateTransactions(validTxs)
      newMempoolAgain.size shouldBe 10
      validTxsAgain.size shouldBe 0
    }
    "mempoolMaxCapacity works correct" in {
      val mempool                = MemoryPoolStorage.empty(testNetSettings, timeProvider)
      val transactions           = genValidPaymentTxs(11)
      val (newMempool, validTxs) = mempool.validateTransactions(transactions)
      newMempool.size shouldBe 10
      validTxs.size shouldBe 10
    }
    "getTransactionsForMiner works fine" in {
      val mempool         = MemoryPoolStorage.empty(testNetSettings, timeProvider)
      val transactions    = (0 until 10).map(k => coinbaseAt(k))
      val (newMempool, _) = mempool.validateTransactions(transactions)
      val (uPool, txs)    = newMempool.getTransactionsForMiner
      uPool.size shouldBe 0
      txs.map(_.encodedId).forall(transactions.map(_.encodedId).contains) shouldBe true
      transactions.map(_.encodedId).forall(txs.map(_.encodedId).contains) shouldBe true
    }
  }
  "Mempool actor" should {
    "send transactions to miner" in {
      val miner1 = TestProbe()
      val mempool1: TestActorRef[MemoryPool] =
        TestActorRef[MemoryPool](MemoryPool.props(testNetSettings, timeProvider, miner1.ref, Some(TestProbe().ref)))
      val transactions1 = (0 until 4).map { k =>
        val a = coinbaseAt(k)
        a
      }
      transactions1.foreach(mempool1 ! NewTransaction(_))
      mempool1.underlyingActor.memoryPool.size shouldBe 4
      logger.info(s"generated: ${transactions1.map(_.encodedId)}")
      miner1.expectMsg(20.seconds, TransactionsForMiner(transactions1))
    }
  }
} 
Example 128
Source File: LevelDbUnitsGenerator.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.utils.levelDBUtils

import com.typesafe.scalalogging.StrictLogging
import encry.storage.levelDb.versionalLevelDB.LevelDbDiff
import encry.storage.levelDb.versionalLevelDB.VersionalLevelDBCompanion.{LevelDBVersion, VersionalLevelDbKey, VersionalLevelDbValue}
import io.iohk.iodb.ByteArrayWrapper
import scorex.utils.Random

import scala.util.{Random => ScalaRandom}

trait LevelDbUnitsGenerator extends StrictLogging {

  val defaultKeySize: Int = 32
  val defaultValueSize: Int = 256

  def generateRandomKey(keySize: Int = defaultKeySize): VersionalLevelDbKey =
    VersionalLevelDbKey @@ Random.randomBytes(keySize)

  def generateRandomValue(valueSize: Int = defaultValueSize): VersionalLevelDbValue =
    VersionalLevelDbValue @@ Random.randomBytes(valueSize)

  def genRandomInsertValue(keySize: Int = defaultKeySize,
                           valueSize: Int = defaultValueSize): (VersionalLevelDbKey, VersionalLevelDbValue) =
    (generateRandomKey(keySize), generateRandomValue(valueSize))

  def generateRandomLevelDbElemsWithoutDeletions(qty: Int, qtyOfElemsToInsert: Int): List[LevelDbDiff] =
    (0 until qty).foldLeft(List.empty[LevelDbDiff]) {
      case (acc, i) =>
        LevelDbDiff(
          LevelDBVersion @@ Random.randomBytes(),
          List((0 until qtyOfElemsToInsert).map(_ => genRandomInsertValue()): _*)
        ) :: acc
    }

  
  def generateRandomLevelDbElemsWithLinkedDeletions(qty: Int, qtyOfElemsToInsert: Int): Seq[LevelDbDiff] =
    (0 until qty).foldLeft(Seq.empty[LevelDbDiff]) {
      case (acc, _) =>
        acc :+ LevelDbDiff(
          LevelDBVersion @@ Random.randomBytes(),
          List((0 until qtyOfElemsToInsert).map(_ => genRandomInsertValue()): _*),
          acc.lastOption.map(_.elemsToInsert.map(_._1)).getOrElse(Seq.empty[VersionalLevelDbKey])
        )
    }

  def generateRandomLevelDbElemsWithSameKeys(qty: Int, qtyOfElemsToInsert: Int): Seq[LevelDbDiff] = {
    (0 until qty).foldLeft(Seq.empty[LevelDbDiff]) {
      case (acc, _) =>
        acc :+ LevelDbDiff(
          LevelDBVersion @@ Random.randomBytes(),
          acc.lastOption
            .map(_.elemsToInsert.map(elem => (elem._1, generateRandomValue())))
            .getOrElse(
              List((0 until qtyOfElemsToInsert).map(_ => genRandomInsertValue()): _*)
            ),
          Seq.empty[VersionalLevelDbKey]
        )
    }
  }
} 
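A short usage sketch of the generator trait (the object name is arbitrary): five versions, each inserting ten random key/value pairs.

object ExampleGen extends LevelDbUnitsGenerator
val diffs = ExampleGen.generateRandomLevelDbElemsWithoutDeletions(qty = 5, qtyOfElemsToInsert = 10)
println(s"Generated ${diffs.size} versions, first one inserts ${diffs.head.elemsToInsert.size} pairs")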
Example 129
Source File: DataTransactionTest.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.it.transactions

import TransactionGenerator.CreateTransaction
import com.typesafe.config.Config
import com.typesafe.scalalogging.StrictLogging
import encry.it.configs.Configs
import encry.it.docker.NodesFromDocker
import encry.it.util.KeyHelper._
import org.encryfoundation.common.crypto.PrivateKey25519
import org.encryfoundation.common.modifiers.history.Block
import org.encryfoundation.common.modifiers.mempool.transaction.{PubKeyLockedContract, Transaction}
import org.encryfoundation.common.modifiers.state.box.{AssetBox, EncryBaseBox}
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{AsyncFunSuite, Matchers}
import scorex.utils.Random
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._

class DataTransactionTest extends AsyncFunSuite
  with Matchers
  with ScalaFutures
  with StrictLogging
  with NodesFromDocker {

  override protected def nodeConfigs: Seq[Config] = Seq(Configs.mining(true)
    .withFallback(Configs.offlineGeneration(true))
    .withFallback(Configs.nodeName("node1")))

  test("Create and send data transaction. Check chain for it.") {

    val firstHeightToWait: Int = 5
    val secondHeightToWait: Int = 8
    val mnemonicKey: String = "index another island accuse valid aerobic little absurd bunker keep insect scissors"
    val privKey: PrivateKey25519 = createPrivKey(Some(mnemonicKey))
    val waitTime: FiniteDuration = 30.minutes
    val fee: Long = scala.util.Random.nextInt(500)

    Await.result(dockerNodes().head.waitForHeadersHeight(firstHeightToWait), waitTime)

    val boxes: Seq[EncryBaseBox] = Await.result(dockerNodes().head.outputs, waitTime)
    val oneBox: AssetBox = boxes.collect { case ab: AssetBox => ab }.head
    val transaction: Transaction = CreateTransaction.dataTransactionScratch(
      privKey,
      fee,
      System.currentTimeMillis(),
      IndexedSeq(oneBox).map(_ -> None),
      PubKeyLockedContract(privKey.publicImage.pubKeyBytes).contract,
      Random.randomBytes(32)
    )

    Await.result(dockerNodes().head.sendTransaction(transaction), waitTime)
    Await.result(dockerNodes().head.waitForHeadersHeight(secondHeightToWait), waitTime)

    val headersAtHeight: List[String] = (firstHeightToWait + 1 to secondHeightToWait)
      .foldLeft(List[String]()) { case (list, blockHeight) =>
        val headers: Future[List[String]] = dockerNodes().head.getHeadersIdAtHeight(blockHeight)
        val result: List[String] = Await.result(headers, waitTime)
        list ::: result
      }

    Await.result(dockerNodes().head.getBlock(headersAtHeight.head), waitTime)

    val lastBlocks: Future[Seq[Block]] = Future.sequence(headersAtHeight.map { h => dockerNodes().head.getBlock(h) })

    lastBlocks.map { blocks =>
      val txsNum: Int = blocks.map(_.payload.txs.size).sum
      docker.close()
      val transactionFromChain: Transaction = blocks.flatMap(_.payload.txs.init).head
      transactionFromChain.id shouldEqual transaction.id
      true shouldEqual (txsNum > secondHeightToWait - firstHeightToWait)
      txsNum shouldEqual (secondHeightToWait - firstHeightToWait + 1)
    }
  }
} 
Example 130
Source File: ProcessingTransferTransactionWithEncryCoinsTest.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.it.transactions

import TransactionGenerator.CreateTransaction
import com.typesafe.config.Config
import com.typesafe.scalalogging.StrictLogging
import encry.consensus.EncrySupplyController
import encry.it.configs.Configs
import encry.it.docker.NodesFromDocker
import encry.it.util.KeyHelper._
import encry.settings.Settings
import org.encryfoundation.common.crypto.{PrivateKey25519, PublicKey25519}
import org.encryfoundation.common.modifiers.history.Block
import org.encryfoundation.common.modifiers.mempool.transaction.EncryAddress.Address
import org.encryfoundation.common.modifiers.mempool.transaction.Transaction
import org.encryfoundation.common.modifiers.state.box.{AssetBox, EncryBaseBox}
import org.encryfoundation.common.utils.Algos
import org.encryfoundation.common.utils.TaggedTypes.Height
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{AsyncFunSuite, Matchers}
import scorex.crypto.signatures.Curve25519
import scorex.utils.Random

import scala.concurrent.duration._
import scala.concurrent.{Await, Future}

class ProcessingTransferTransactionWithEncryCoinsTest extends AsyncFunSuite
  with Matchers
  with ScalaFutures
  with StrictLogging
  with NodesFromDocker
  with Settings {

  override protected def nodeConfigs: Seq[Config] = Seq(Configs.mining(true)
    .withFallback(Configs.offlineGeneration(true))
    .withFallback(Configs.nodeName("node1")))

  test("Create and send monetary transaction. Check balance.") {

    val amount: Int = scala.util.Random.nextInt(2000)
    val fee: Long = scala.util.Random.nextInt(500)
    val firstHeightToWait: Int = 5
    val secondHeightToWait: Int = 8
    val mnemonicKey: String = "index another island accuse valid aerobic little absurd bunker keep insect scissors"
    val privKey: PrivateKey25519 = createPrivKey(Some(mnemonicKey))
    val recipientAddress: Address = PublicKey25519(Curve25519.createKeyPair(Random.randomBytes())._2).address.address
    val waitTime: FiniteDuration = 30.minutes

    val supplyAtHeight: Long = (0 to secondHeightToWait).foldLeft(0: Long) {
      case (supply, i) => supply + EncrySupplyController.supplyAt(Height @@ i, settings.constants)
    }

    Await.result(dockerNodes().head.waitForHeadersHeight(firstHeightToWait), waitTime)

    val boxes: Seq[EncryBaseBox] = Await.result(dockerNodes().head.outputs, waitTime)
    val oneBox: AssetBox = boxes.collect { case ab: AssetBox => ab }.head
    val transaction: Transaction = CreateTransaction.defaultPaymentTransaction(
      privKey,
      fee,
      System.currentTimeMillis(),
      IndexedSeq(oneBox).map(_ -> None),
      recipientAddress,
      amount
    )

    Await.result(dockerNodes().head.sendTransaction(transaction), waitTime)
    Await.result(dockerNodes().head.waitForHeadersHeight(secondHeightToWait), waitTime)

    val checkBalance: Boolean = Await.result(dockerNodes().head.balances, waitTime)
      .find(_._1 == Algos.encode(settings.constants.IntrinsicTokenId))
      .map(_._2 == supplyAtHeight - amount)
      .get

    val headersAtHeight: List[String] = (firstHeightToWait + 1 to secondHeightToWait)
      .foldLeft(List[String]()) { case (list, blockHeight) =>
        val headers: Future[List[String]] = dockerNodes().head.getHeadersIdAtHeight(blockHeight)
        val result: List[String] = Await.result(headers, waitTime)
        list ::: result
      }

    Await.result(dockerNodes().head.getBlock(headersAtHeight.head), waitTime)

    val lastBlocks: Future[Seq[Block]] = Future.sequence(headersAtHeight.map { h => dockerNodes().head.getBlock(h) })

    lastBlocks.map { blocks =>
      val txsNum: Int = blocks.map(_.payload.txs.size).sum
      docker.close()
      val transactionFromChain: Transaction = blocks.flatMap(_.payload.txs.init).head
      transactionFromChain.id shouldEqual transaction.id
      true shouldEqual (txsNum > secondHeightToWait - firstHeightToWait)
      txsNum shouldEqual (secondHeightToWait - firstHeightToWait + 1)
      checkBalance shouldBe true
    }
  }
} 
Example 131
Source File: StorageNodeActor.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.actors

import akka.actor.{Actor, ActorRef, Props, RootActorPath, Terminated}
import akka.cluster.ClusterEvent.{CurrentClusterState, MemberUp}
import akka.cluster.{Cluster, Member, MemberStatus}
import com.typesafe.scalalogging.StrictLogging
import justin.db.actors.protocol.{RegisterNode, _}
import justin.db.cluster.ClusterMembers
import justin.db.cluster.datacenter.Datacenter
import justin.db.consistenthashing.{NodeId, Ring}
import justin.db.replica._
import justin.db.replica.read.{ReplicaLocalReader, ReplicaReadCoordinator, ReplicaRemoteReader}
import justin.db.replica.write.{ReplicaLocalWriter, ReplicaRemoteWriter, ReplicaWriteCoordinator}
import justin.db.storage.PluggableStorageProtocol

import scala.concurrent.ExecutionContext

class StorageNodeActor(nodeId: NodeId, datacenter: Datacenter, storage: PluggableStorageProtocol, ring: Ring, n: N) extends Actor with StrictLogging {

  private[this] implicit val ec: ExecutionContext = context.dispatcher
  private[this] val cluster = Cluster(context.system)

  private[this] var clusterMembers   = ClusterMembers.empty
  private[this] val readCoordinator  = new ReplicaReadCoordinator(nodeId, ring, n, new ReplicaLocalReader(storage), new ReplicaRemoteReader)
  private[this] val writeCoordinator = new ReplicaWriteCoordinator(nodeId, ring, n, new ReplicaLocalWriter(storage), new ReplicaRemoteWriter)

  private[this] val coordinatorRouter = context.actorOf(
    props = RoundRobinCoordinatorRouter.props(readCoordinator, writeCoordinator),
    name  = RoundRobinCoordinatorRouter.routerName
  )

  private[this] val name = self.path.name

  override def preStart(): Unit = cluster.subscribe(this.self, classOf[MemberUp])
  override def postStop(): Unit = cluster.unsubscribe(this.self)

  def receive: Receive = {
    receiveDataPF orElse receiveClusterDataPF orElse receiveRegisterNodePR orElse notHandledPF
  }

  private[this] def receiveDataPF: Receive = {
    case readReq: StorageNodeReadRequest              =>
      coordinatorRouter ! ReadData(sender(), clusterMembers, readReq)
    case writeLocalDataReq: StorageNodeWriteDataLocal =>
      coordinatorRouter ! WriteData(sender(), clusterMembers, writeLocalDataReq)
    case writeClientReplicaReq: Internal.WriteReplica =>
      coordinatorRouter ! WriteData(sender(), clusterMembers, writeClientReplicaReq)
  }

  private[this] def receiveClusterDataPF: Receive = {
    case "members"                  => sender() ! clusterMembers
    case MemberUp(member)           => register(nodeId, ring, member)
    case state: CurrentClusterState => state.members.filter(_.status == MemberStatus.Up).foreach(member => register(nodeId, ring, member))
    case Terminated(actorRef)       => clusterMembers = clusterMembers.removeByRef(StorageNodeActorRef(actorRef))
  }

  private[this] def receiveRegisterNodePR: Receive = {
    case RegisterNode(senderNodeId) if clusterMembers.notContains(senderNodeId) =>
      val senderRef = sender()
      context.watch(senderRef)
      clusterMembers = clusterMembers.add(senderNodeId, StorageNodeActorRef(senderRef))
      senderRef ! RegisterNode(nodeId)
      logger.info(s"Actor[$name]: Successfully registered node [id-${senderNodeId.id}]")
    case RegisterNode(senderNodeId) =>
      logger.info(s"Actor[$name]: Node [id-${senderNodeId.id}] is already registered")
  }

  private[this] def register(nodeId: NodeId, ring: Ring, member: Member) = {
    (member.hasRole(StorageNodeActor.role), datacenter.name == member.dataCenter) match {
      case (true, true) => register()
      case (_,   false) => logger.info(s"Actor[$name]: $member doesn't belong to datacenter [${datacenter.name}]")
      case (false,   _) => logger.info(s"Actor[$name]: $member doesn't have [${StorageNodeActor.role}] role (it has roles ${member.roles})")
    }

    def register() = for {
      ringNodeId    <- ring.nodesId
      nodeName       = StorageNodeActor.name(ringNodeId, Datacenter(member.dataCenter))
      nodeRef        = context.actorSelection(RootActorPath(member.address) / "user" / nodeName)
    } yield nodeRef ! RegisterNode(nodeId)
  }

  private[this] def notHandledPF: Receive = {
    case t => logger.warn(s"Actor[$name]: Not handled message [$t]")
  }
}

object StorageNodeActor {
  def role: String = "storagenode"
  def name(nodeId: NodeId, datacenter: Datacenter): String = s"${datacenter.name}-id-${nodeId.id}"
  def props(nodeId: NodeId, datacenter: Datacenter, storage: PluggableStorageProtocol, ring: Ring, n: N): Props = {
    Props(new StorageNodeActor(nodeId, datacenter, storage, ring, n))
  }
}

case class StorageNodeActorRef(ref: ActorRef) extends AnyVal 
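A condensed sketch of starting one node, close to what JustinDB.init does in Example 134 below; `storage` is assumed to be an already-initialised PluggableStorageProtocol, and the node id, datacenter name, ring size and replication factor are illustrative.

import akka.actor.{ActorRef, ActorSystem}
import justin.db.cluster.datacenter.Datacenter
import justin.db.consistenthashing.{NodeId, Ring}
import justin.db.replica.N
import justin.db.storage.PluggableStorageProtocol

def startStorageNode(system: ActorSystem, storage: PluggableStorageProtocol): ActorRef = {
  val nodeId     = NodeId(0)
  val datacenter = Datacenter("dc1") // illustrative datacenter name
  val ring       = Ring(3, 64)       // 3 members, 64 partitions (illustrative)
  system.actorOf(
    props = StorageNodeActor.props(nodeId, datacenter, storage, ring, N(3)),
    name  = StorageNodeActor.name(nodeId, datacenter)
  )
}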
Example 132
Source File: SerializerInit.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.kryo

import com.esotericsoftware.kryo.Kryo
import com.typesafe.scalalogging.StrictLogging

class SerializerInit extends StrictLogging {

  def customize(kryo: Kryo): Unit = {
    logger.info("Initialized Kryo")

    // cluster
    kryo.register(classOf[justin.db.actors.protocol.RegisterNode], RegisterNodeSerializer, 50)

    // write -- request
    kryo.register(classOf[justin.db.actors.protocol.StorageNodeWriteDataLocal], StorageNodeWriteDataLocalSerializer, 60)

    // write -- responses
    kryo.register(classOf[justin.db.actors.protocol.StorageNodeFailedWrite],     StorageNodeWriteResponseSerializer, 70)
    kryo.register(classOf[justin.db.actors.protocol.StorageNodeSuccessfulWrite], StorageNodeWriteResponseSerializer, 71)
    kryo.register(classOf[justin.db.actors.protocol.StorageNodeConflictedWrite], StorageNodeWriteResponseSerializer, 72)

    // read - request
    kryo.register(classOf[justin.db.actors.protocol.StorageNodeLocalRead], StorageNodeLocalReadSerializer, 80)

    // read - responses
    kryo.register(classOf[justin.db.actors.protocol.StorageNodeFoundRead],      StorageNodeReadResponseSerializer, 90)
    kryo.register(classOf[justin.db.actors.protocol.StorageNodeConflictedRead], StorageNodeReadResponseSerializer, 91)
    kryo.register(classOf[justin.db.actors.protocol.StorageNodeNotFoundRead],   StorageNodeReadResponseSerializer, 92)
    kryo.register(classOf[justin.db.actors.protocol.StorageNodeFailedRead],     StorageNodeReadResponseSerializer, 93)

    ()
  }
} 
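A minimal sketch of invoking the initializer directly against a Kryo instance; in a real deployment this hook would normally be referenced from the serializer configuration, which is outside the scope of this example.

import com.esotericsoftware.kryo.Kryo

val kryo = new Kryo()
new SerializerInit().customize(kryo) // logs "Initialized Kryo" and registers the protocol classes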
Example 133
Source File: Main.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db

import akka.actor.ActorSystem
import buildinfo.BuildInfo
import com.typesafe.scalalogging.StrictLogging

// $COVERAGE-OFF$
object Main extends App with StrictLogging {

  logger.info(
    """
      |   ___              _    _        ______ ______
      |  |_  |            | |  (_)       |  _  \| ___ \
      |    | | _   _  ___ | |_  _  _ __  | | | || |_/ /
      |    | || | | |/ __|| __|| || '_ \ | | | || ___ \
      |/\__/ /| |_| |\__ \| |_ | || | | || |/ / | |_/ /
      |\____/  \__,_||___/ \__||_||_| |_||___/  \____/
      |
    """.stripMargin
  )

  val justindbConfig = JustinDBConfig.init
  val actorSystem    = ActorSystem(justindbConfig.system, justindbConfig.config)
  val justindb       = JustinDB.init(justindbConfig)(actorSystem)

  logger.info("Build Info: " + BuildInfo.toString)
}
// $COVERAGE-ON$ 
Example 134
Source File: JustinDB.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db

import akka.actor.ActorSystem
import akka.cluster.Cluster
import akka.cluster.http.management.ClusterHttpManagement
import akka.http.scaladsl.Http
import akka.http.scaladsl.server.Directives._
import akka.stream.{ActorMaterializer, Materializer}
import buildinfo.BuildInfo
import com.typesafe.scalalogging.StrictLogging
import justin.db.actors.{StorageNodeActor, StorageNodeActorRef}
import justin.db.client.ActorRefStorageNodeClient
import justin.db.cluster.datacenter.Datacenter
import justin.db.consistenthashing.{NodeId, Ring}
import justin.db.replica.N
import justin.db.storage.PluggableStorageProtocol
import justin.db.storage.provider.StorageProvider
import justin.httpapi.{BuildInfoRouter, HealthCheckRouter, HttpRouter}

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Promise}
import scala.language.reflectiveCalls

// $COVERAGE-OFF$
final class JustinDB

object JustinDB extends StrictLogging {

  private[this] def validConfiguration(justinDBConfig: JustinDBConfig): Unit = {
    require(justinDBConfig.replication.N > 0, "replication N factor can't be smaller or equal 0")
    require(justinDBConfig.ring.`members-count` > 0, "members-counter can't be smaller or equal 0")
    require(justinDBConfig.ring.partitions > 0, "ring partitions can't be smaller or equal 0")
    require(justinDBConfig.ring.partitions >= justinDBConfig.ring.`members-count`, "number of ring partitions can't be smaller than number of members-count")
    require(justinDBConfig.replication.N <= justinDBConfig.ring.`members-count`, "replication N factor can't be bigger than defined members-count number")
  }

  private[this] def initStorage(justinConfig: JustinDBConfig) = {
    val provider = StorageProvider.apply(justinConfig.storage.provider)
    logger.info("Storage provider: " + provider.name)
    provider.init
  }

  def init(justinConfig: JustinDBConfig)(implicit actorSystem: ActorSystem): JustinDB = {
    validConfiguration(justinConfig)

    val processOrchestrator = Promise[JustinDB]

    implicit val executor: ExecutionContext = actorSystem.dispatcher
    implicit val materializer: Materializer = ActorMaterializer()

    val storage: PluggableStorageProtocol = initStorage(justinConfig)

    val cluster = Cluster(actorSystem)

    cluster.registerOnMemberUp {
      // STORAGE ACTOR
      val storageNodeActorRef = StorageNodeActorRef {
        val nodeId     = NodeId(justinConfig.`kubernetes-hostname`.split("-").last.toInt)
        val ring       = Ring(justinConfig.ring.`members-count`, justinConfig.ring.partitions)
        val n          = N(justinConfig.replication.N)
        val datacenter = Datacenter(justinConfig.dc.`self-data-center`)

        actorSystem.actorOf(
          props = StorageNodeActor.props(nodeId, datacenter, storage, ring, n),
          name  = StorageNodeActor.name(nodeId, datacenter)
        )
      }

      // AKKA-MANAGEMENT
      ClusterHttpManagement(cluster).start().map { _ =>
        logger.info("Cluster HTTP-Management is ready!")
      }.recover { case ex => processOrchestrator.failure(ex) }

      // HTTP API
      val routes = logRequestResult(actorSystem.name) {
        new HttpRouter(new ActorRefStorageNodeClient(storageNodeActorRef)).routes ~
          new HealthCheckRouter().routes ~
          new BuildInfoRouter().routes(BuildInfo.toJson)
      }
      Http()
        .bindAndHandle(routes, justinConfig.http.interface, justinConfig.http.port)
        .map { binding => logger.info(s"HTTP server started at ${binding.localAddress}"); processOrchestrator.trySuccess(new JustinDB) }
        .recover { case ex => logger.error("Could not start HTTP server", ex); processOrchestrator.failure(ex) }
    }

    Await.result(processOrchestrator.future, 2.minutes)
  }
}
// $COVERAGE-ON$ 
Example 135
Source File: package.scala    From gatling-amqp-plugin   with Apache License 2.0 5 votes vote down vote up
package ru.tinkoff.gatling.amqp

import com.typesafe.scalalogging.StrictLogging
import ru.tinkoff.gatling.amqp.request.AmqpProtocolMessage

package object action {
  object Around {
    def apply(before: Unit, after: Unit): Around = new Around(() => before, () => after)
  }
  class Around(before: () => Unit, after: () => Unit) {

    def apply(f: => Any): Unit = {
      before()
      f
      after()
    }
  }

  trait AmqpLogging extends StrictLogging {
    def logMessage(text: => String, msg: AmqpProtocolMessage): Unit = {
      logger.debug(text)
      logger.trace(msg.toString)
    }
  }

  sealed trait Dest
  case class DirectDest(exchName: String, rk: String) extends Dest
  case class QueueDest(qName: String)                 extends Dest
} 
Example 136
Source File: ValueStoreSerializationExt.scala    From random-projections-at-berlinbuzzwords   with Apache License 2.0 5 votes vote down vote up
package com.stefansavev.randomprojections.serialization

import java.io.File
import com.stefansavev.core.serialization.Utils
import com.stefansavev.randomprojections.datarepr.dense.store.ValuesStore
import com.stefansavev.randomprojections.serialization.DataFrameViewSerializers._
import com.typesafe.scalalogging.StrictLogging

object ValueStoreSerializationExt {
  val ser = valuesStoreSerializer()

  implicit class ValueStoreSerializerExt(input: ValuesStore) {
    def toFile(file: File): Unit = {
      Utils.toFile(ser, file, input)
    }

    def toFile(fileName: String): Unit = {
      toFile(new File(fileName))
    }

    def toBytes(): Array[Byte] = {
      Utils.toBytes(ser, input)
    }
  }

  implicit class ValueStoreDeserializerExt(t: ValuesStore.type) extends StrictLogging {
    def fromFile(file: File): ValuesStore = {
      if (!file.exists()) {
        throw new IllegalStateException("file does not exist: " + file.getAbsolutePath)
      }
      logger.info("Loading file: " + file.getAbsolutePath)
      val output = Utils.fromFile(ser, file)
      output
    }

    def fromFile(fileName: String): ValuesStore = {
      fromFile(new File(fileName))
    }

    def fromBytes(input: Array[Byte]): ValuesStore = {
      Utils.fromBytes(ser, input)
    }
  }

} 
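A sketch of the implicit extension methods in use; `valuesStore` is assumed to be an already-built ValuesStore and the path is illustrative.

import com.stefansavev.randomprojections.datarepr.dense.store.ValuesStore
import com.stefansavev.randomprojections.serialization.ValueStoreSerializationExt._

def saveAndReload(valuesStore: ValuesStore): ValuesStore = {
  valuesStore.toFile("/tmp/values-store.bin")   // via ValueStoreSerializerExt
  ValuesStore.fromFile("/tmp/values-store.bin") // via ValueStoreDeserializerExt, logs "Loading file: ..."
}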
Example 137
Source File: BucketCollectorImpl.scala    From random-projections-at-berlinbuzzwords   with Apache License 2.0 5 votes vote down vote up
package com.stefansavev.randomprojections.implementation

import java.io._
import com.stefansavev.core.serialization.IntArraySerializer
import com.stefansavev.randomprojections.buffers.IntArrayBuffer
import com.stefansavev.randomprojections.datarepr.dense.PointIndexes
import com.stefansavev.randomprojections.datarepr.dense.store.FixedLengthBuffer
import com.stefansavev.randomprojections.interface.{BucketCollector, Index}
import com.typesafe.scalalogging.StrictLogging

object BucketCollectorImplUtils {
  val partitionFileSuffix = "_partition_"

  def fileName(dirName: String, partitionId: Int): String = {
    (new File(dirName, partitionFileSuffix + partitionId)).getAbsolutePath
  }
}

class BucketCollectorImpl(backingDir: String, totalRows: Int) extends BucketCollector with StrictLogging {
  val pointIdsThreshold = 1 << 20

  var leafId = 0
  var starts = new IntArrayBuffer()
  var pointIds = new IntArrayBuffer()
  var globalIndex = 0
  var numPartitions = 0

  def savePartial(): Unit = {
    val fileName = BucketCollectorImplUtils.fileName(backingDir, numPartitions)
    val outputStream = new BufferedOutputStream(new FileOutputStream(fileName))
    logger.info(s"Writing partial buckets to file $fileName")
    IntArraySerializer.write(outputStream, starts.toArray())
    IntArraySerializer.write(outputStream, pointIds.toArray())
    outputStream.close()
    starts = new IntArrayBuffer()
    pointIds = new IntArrayBuffer()
    numPartitions += 1
  }

  def collectPoints(names: PointIndexes): RandomTreeLeaf = {
    val leaf = RandomTreeLeaf(leafId, names.size)

    starts += globalIndex
    globalIndex += names.size
    pointIds ++= names.indexes
    leafId += 1

    if (pointIds.size > pointIdsThreshold) {
      savePartial()
    }
    leaf
  }

  def getIntArrayAndClear(buffer: IntArrayBuffer): Array[Int] = {
    val result = buffer.toArray
    buffer.clear()
    result
  }

  def build(pointSignatures: PointSignatures, labels: Array[Int]): Index = {
    starts += globalIndex
    savePartial()
    //TODO: now leaving leaf2Points as null, fix it
    //val leaf2Points = new Leaf2Points(getIntArrayAndClear(starts), getIntArrayAndClear(pointIds))
    val index = new IndexImpl(pointSignatures, labels.length, Some((backingDir, leafId + 1, globalIndex, numPartitions)), null, labels)
    index
  }
}

object BucketCollectorImpl {
  def mergeLeafData(backingDir: String, startBufferLen: Int, numPoints: Int, numPartitions: Int): Leaf2Points = {
    val starts = new FixedLengthBuffer[Int](startBufferLen)
    val pointsIds = new FixedLengthBuffer[Int](numPoints)
    for (i <- 0 until numPartitions) {
      val fileName = BucketCollectorImplUtils.fileName(backingDir, i)
      val inputStream = new BufferedInputStream(new FileInputStream(fileName))
      starts ++= IntArraySerializer.read(inputStream)
      pointsIds ++= IntArraySerializer.read(inputStream)
      inputStream.close()
    }
    new Leaf2Points(starts.array, pointsIds.array)
  }
} 
Example 138
Source File: ValuesStoreTest.scala    From random-projections-at-berlinbuzzwords   with Apache License 2.0 5 votes vote down vote up
package com.stefansavev

import java.util.Random

import com.stefansavev.randomprojections.datarepr.dense.store._
import com.typesafe.scalalogging.StrictLogging
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TestSingleByteEncodingSpec extends FlatSpec with Matchers {
  "Error after encoding double to float" should "be small" in {
    val minV = -1.0f
    val maxV = 2.0f
    val rnd = new Random(481861)
    for (i <- 0 until 100) {
      //we encode a float (which is 4 bytes) with a single byte
      //therefore the loss of precision
      val value = rnd.nextFloat() * 3.0f - 1.0f
      val enc = FloatToSingleByteEncoder.encodeValue(minV, maxV, value)
      val dec = FloatToSingleByteEncoder.decodeValue(minV, maxV, enc)
      val error = Math.abs(value - dec)
      error should be < (0.01)
    }
  }
}

@RunWith(classOf[JUnitRunner])
class TestValueStores extends FlatSpec with Matchers {

  case class BuilderTypeWithErrorPredicate(builderType: StoreBuilderType, pred: Double => Boolean)

  "ValueStore" should "return store the data with small error" in {

    val tests = List(
      BuilderTypeWithErrorPredicate(StoreBuilderAsDoubleType, error => (error <= 0.0)),
      BuilderTypeWithErrorPredicate(StoreBuilderAsBytesType, error => (error <= 0.01)),
      BuilderTypeWithErrorPredicate(StoreBuilderAsSingleByteType, error => (error <= 0.01))
    )

    for (test <- tests) {
      testBuilder(test)
    }

    def testBuilder(builderWithPred: BuilderTypeWithErrorPredicate): Unit = {
      val dataGenSettings = RandomBitStrings.RandomBitSettings(
        numGroups = 1000,
        numRowsPerGroup = 2,
        numCols = 256,
        per1sInPrototype = 0.5,
        perNoise = 0.2)

      val debug = false
      val randomBitStringsDataset = RandomBitStrings.genRandomData(58585, dataGenSettings, debug, true)
      val builder = builderWithPred.builderType.getBuilder(randomBitStringsDataset.numCols)

      def addValues(): Unit = {
        var i = 0
        while (i < randomBitStringsDataset.numRows) {
          val values = randomBitStringsDataset.getPointAsDenseVector(i)
          builder.addValues(values)
          i += 1
        }
      }

      addValues()

      val valueStore = builder.build()

      def verifyStoredValues(expected: Array[Double], stored: Array[Double]): Unit = {
        for (i <- 0 until expected.length) {
          val error = Math.abs(expected(i) - stored(i))
          val passed = builderWithPred.pred(error)
          passed should be (true)
        }
      }

      def testValues(): Unit = {
        var i = 0
        while (i < randomBitStringsDataset.numRows) {
          val values = randomBitStringsDataset.getPointAsDenseVector(i)
          val output = Array.ofDim[Double](randomBitStringsDataset.numCols)
          valueStore.fillRow(i, output, true)
          verifyStoredValues(values, output)
          i += 1
        }
      }
      testValues()
    }
  }
}


object Test extends StrictLogging {
  def main(args: Array[String]) {
    logger.info("hello")
  }
} 
Example 139
Source File: TestOnRandomData.scala    From random-projections-at-berlinbuzzwords   with Apache License 2.0 5 votes vote down vote up
package com.stefansavev.fuzzysearchtest

import java.util.Random

import com.stefansavev.randomprojections.actors.Application
import com.stefansavev.randomprojections.implementation._
import com.stefansavev.randomprojections.utils.Utils
import com.stefansavev.similaritysearch.SimilaritySearchEvaluationUtils
import com.stefansavev.similaritysearch.VectorType.StorageSize
import com.stefansavev.similaritysearch.implementation.FuzzySearchIndexBuilderWrapper
import com.typesafe.scalalogging.StrictLogging


object TestOnRandomData extends StrictLogging {
  implicit val _ = logger

  def main(args: Array[String]): Unit = {
    val dataGenSettings = RandomBitStrings.RandomBitSettings(
      numGroups = 100000,
      numRowsPerGroup = 2,
      numCols = 256,
      per1sInPrototype = 0.5,
      perNoise = 0.1)

    val debug = false
    val randomBitStringsDataset = RandomBitStrings.genRandomData(58585, dataGenSettings, debug, true)

    val randomTreeSettings = IndexSettings(
      maxPntsPerBucket = 50,
      numTrees = 50,
      maxDepth = None,
      projectionStrategyBuilder = ProjectionStrategies.splitIntoKRandomProjection(),
      reportingDistanceEvaluator = ReportingDistanceEvaluators.cosineOnOriginalData(),
      randomSeed = 39393
    )

    println("Number of Rows: " + randomBitStringsDataset.numRows)
    val diskLocation = "D:/tmp/randomfile"
    val trees = Utils.timed("Build Index", {
      val wrapper = new FuzzySearchIndexBuilderWrapper(diskLocation, randomBitStringsDataset.numCols, 50, StorageSize.Double)
      var i = 0
      while (i < randomBitStringsDataset.numRows) {
        wrapper.addItem(i.toString, 0, randomBitStringsDataset.getPointAsDenseVector(i))
        i += 1
      }
      wrapper.build()
      //SimilaritySearchIndex.open(diskLocation)
      ()
    }).result

    SimilaritySearchEvaluationUtils.compareWithBruteForce(diskLocation, new Random(481868), 1000, 50)

    

    Application.shutdown()
  }
} 
Example 140
Source File: CIFARTrainer.scala    From cct-nn   with Apache License 2.0 5 votes vote down vote up
package toolkit.neuralnetwork.examples

import cogio.FieldState
import com.typesafe.scalalogging.StrictLogging
import libcog._
import toolkit.neuralnetwork.WeightStore
import toolkit.neuralnetwork.examples.networks.Net


object CIFARTrainer extends App with StrictLogging {
  val batchSize = 100
  val netName = 'SimpleConvNet

  def validate(snapshot: Map[Symbol, FieldState]): (Float, Float) = {
    val cg = new ComputeGraph {
      val net = Net(netName, useRandomData = false, learningEnabled = false, batchSize = batchSize,
        training = false, weights = WeightStore.restoreFromSnapshot(snapshot))

      probe(net.correct)
      probe(net.loss.forward)

      def readLoss(): Float = {
        read(net.loss.forward).asInstanceOf[ScalarFieldReader].read()
      }

      def readCorrect(): Float = {
        read(net.correct).asInstanceOf[ScalarFieldReader].read()
      }
    }

    val steps = 10000 / batchSize
    var lossAcc = 0f
    var correctAcc = 0f

    cg withRelease {
      cg.reset

      for (i <- 0 until steps) {
        lossAcc += cg.readLoss()
        correctAcc += cg.readCorrect()

        cg.step
      }
    }

    (lossAcc / steps, correctAcc / steps)
  }

  val cg = new ComputeGraph {
    val net = Net(netName, useRandomData = false, learningEnabled = true, batchSize = batchSize)

    probe(net.correct)
    probe(net.loss.forward)

    def readLoss(): Float = {
      read(net.loss.forward).asInstanceOf[ScalarFieldReader].read()
    }

    def readCorrect(): Float = {
      read(net.correct).asInstanceOf[ScalarFieldReader].read()
    }
  }

  cg withRelease {
    logger.info(s"starting compilation")
    cg.reset
    logger.info(s"compilation finished")

    val loss = cg.readLoss()
    val correct = cg.readCorrect()

    logger.info(s"initial loss: $loss")
    logger.info(s"initial accuracy: $correct")

    logger.info(s"Iteration: 0 Sample: 0 Training Loss: $loss Training Accuracy: $correct")

    for (i <- 1 to 50000) {
      cg.step

      if (i % 100 == 0) {
        val loss = cg.readLoss()
        val correct = cg.readCorrect()
        logger.info(s"Iteration: $i Sample: ${i * batchSize} Training Loss: $loss Training Accuracy: $correct")
      }

      if (i % 500 == 0) {
        logger.info(s"Validating...")
        val (loss, correct) = validate(cg.net.weights.snapshot(cg))
        logger.info(s"Iteration: $i Sample: ${i * batchSize} Validation Loss: $loss Validation Accuracy: $correct")
      }
    }
  }
} 
Example 141
Source File: Benchmark.scala    From cct-nn   with Apache License 2.0 5 votes vote down vote up
package toolkit.neuralnetwork.performance

import com.typesafe.scalalogging.StrictLogging
import toolkit.neuralnetwork.examples.networks.CIFAR

import scala.collection.mutable.ListBuffer
import libcog._


object Benchmark extends App with StrictLogging {
  val (net, batchSize) = args.length match {
    case 0 => ("cifar10_quick", 256)
    case 1 => (args(0), 256)
    case 2 => (args(0), args(1).toInt)
    case _ => throw new RuntimeException(s"illegal arguments (${args.toList})")
  }

  require(net == "cifar10_quick", s"network $net isn't supported")

  logger.info(s"net: $net")
  logger.info(s"batch size: $batchSize")

  val cg1 = new ComputeGraph {
    val net = new CIFAR(useRandomData = true, learningEnabled = false, batchSize = batchSize)
  }

  val forward = new ListBuffer[Double]()
  val backward = new ListBuffer[Double]()

  cg1 withRelease {
    logger.info(s"starting compilation (inference)")
    cg1.step
    logger.info(s"compilation finished (inference)")

    for (i <- 1 to 50) {
      val start = System.nanoTime()
      cg1.step
      val stop = System.nanoTime()
      val elapsed = (stop - start).toDouble / 1e6
      logger.info(s"Iteration: $i forward time: $elapsed ms.")
      forward += elapsed
    }
  }

  val cg2 = new ComputeGraph {
    val net = new CIFAR(useRandomData = true, learningEnabled = true, batchSize = batchSize)
  }

  cg2 withRelease {
    logger.info(s"starting compilation (learning)")
    cg2.step
    logger.info(s"compilation finished (learning)")

    for (i <- 1 to 50) {
      val start = System.nanoTime()
      cg2.step
      val stop = System.nanoTime()
      val elapsed = (stop - start).toDouble / 1e6
      logger.info(s"Iteration: $i forward-backward time: $elapsed ms.")
      backward += elapsed
    }
  }

  logger.info(s"Average Forward pass: ${forward.sum / forward.length} ms.")
  logger.info(s"Average Forward-Backward: ${backward.sum / backward.length} ms.")
} 
Example 142
Source File: SourceCodeInfo.scala    From CodeAnalyzerTutorial   with Apache License 2.0 5 votes vote down vote up
package tutor

import com.typesafe.scalalogging.StrictLogging
import tutor.utils.FileUtil
import tutor.utils.FileUtil._

import scala.util.Try

final case class SourceCodeInfo(path: String, localPath: String, lineCount: Int)

object SourceCodeInfo {

  implicit object SourceCodeInfoOrdering extends Ordering[SourceCodeInfo] {
    override def compare(x: SourceCodeInfo, y: SourceCodeInfo): Int = x.lineCount compare y.lineCount
  }

}

trait SourceCodeAnalyzer extends StrictLogging {
  def processFile(path: Path): Try[SourceCodeInfo] = {
    import scala.io._
    Try {
      val source = Source.fromFile(path)
      try {
        val lines = source.getLines.toList
        SourceCodeInfo(path, FileUtil.extractLocalPath(path), lines.length)
      } catch {
        case e: Throwable => throw new IllegalArgumentException(s"error processing file $path", e)
      } finally {
        source.close()
      }
    }
  }
} 
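
A minimal usage sketch for the trait above (not part of the original project). The object name and the file path are illustrative assumptions; it also relies on FileUtil.Path being a String alias, which the SourceCodeInfo constructor call in processFile implies.

package tutor

import scala.util.{Failure, Success}

object SourceCodeAnalyzerDemo extends App with SourceCodeAnalyzer {
  // hypothetical path; any readable source file would do
  processFile("src/main/scala/tutor/SourceCodeInfo.scala") match {
    case Success(info) => logger.info(s"${info.localPath}: ${info.lineCount} lines")
    case Failure(e)    => logger.warn(s"analysis failed: ${e.getMessage}")
  }
}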
Example 143
Source File: DirectoryScanner.scala    From CodeAnalyzerTutorial   with Apache License 2.0 5 votes vote down vote up
package tutor

import java.io.File

import com.typesafe.scalalogging.StrictLogging
import tutor.utils.FileUtil
import tutor.utils.FileUtil.Path

trait DirectoryScanner extends StrictLogging {
  
  def scan(path: Path, knownFileTypes: Set[String], ignoreFolders: Set[String]): Seq[Path] = {
    scan(path)(Vector[Path](), ignoreFolders) {
      (acc, f) =>
        val filePath = f.getAbsolutePath
        if (f.isFile && shouldAccept(f.getPath, knownFileTypes)) {
          acc :+ filePath
        } else acc
    }
  }

  def scan[T](path: Path)(initValue: T, ignoreFolders: Set[String])(processFile: (T, File) => T): T = {
    val files = new File(path).listFiles()
    if (files == null) {
      logger.warn(s"$path is not a legal directory")
      initValue
    } else {
      files.foldLeft(initValue) { (acc, file) =>
        val filePath = file.getAbsolutePath
        if (file.isFile) {
          processFile(acc, file)
        } else if (file.isDirectory && (!ignoreFolders.contains(FileUtil.extractLocalPath(file.getPath)))) {
          scan(filePath)(acc, ignoreFolders)(processFile)
        } else {
          acc
        }
      }
    }
  }

  def foreachFile(path: Path, knownFileTypes: Set[String], ignoreFolders: Set[String])(processFile: File => Unit): Unit = {
    scan(path)((), ignoreFolders) {
      (acc, f) =>
        val filePath = f.getAbsolutePath
        if (f.isFile && shouldAccept(f.getPath, knownFileTypes)) {
          processFile(f)
        } else ()
    }
  }

  private def shouldAccept(path: Path, knownFileTypes: Set[String]): Boolean = {
    knownFileTypes.contains(FileUtil.extractExtFileName(path))
  }
} 
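
A short usage sketch for the scanner trait above; the directory, file extensions, and ignore list are assumptions chosen only for illustration.

package tutor

object DirectoryScannerDemo extends App with DirectoryScanner {
  // recursively collect .scala and .java files, skipping build output and IDE folders
  val sources = scan("src", Set("scala", "java"), Set("target", ".idea"))
  sources.foreach(path => logger.info(s"found $path"))
  logger.info(s"${sources.size} source files in total")
}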
Example 144
Source File: BenchmarkUtil.scala    From CodeAnalyzerTutorial   with Apache License 2.0 5 votes vote down vote up
package tutor.utils

import java.text.SimpleDateFormat
import java.util.Date

import com.typesafe.scalalogging.StrictLogging

object BenchmarkUtil extends StrictLogging {
  def record[T](actionDesc: String)(action: => T): T = {
    val beginTime = new Date
    logger.info(s"begin $actionDesc")
    val rs = action
    logger.info(s"end $actionDesc")
    val endTime = new Date
    val elapsed = new Date(endTime.getTime - beginTime.getTime)
    val sdf = new SimpleDateFormat("mm:ss.SSS")
    logger.info(s"$actionDesc total elapsed ${sdf.format(elapsed)}")
    rs
  }
  def recordStart(actionDesc: String):Date = {
    logger.info(s"$actionDesc begin")
    new Date
  }

  def recordElapse(actionDesc: String, beginFrom: Date):Unit = {
    logger.info(s"$actionDesc ended")
    val endTime = new Date
    val elapsed = new Date(endTime.getTime - beginFrom.getTime)
    val sdf = new SimpleDateFormat("mm:ss.SSS")
    logger.info(s"$actionDesc total elapsed ${sdf.format(elapsed)}")
  }
} 
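
The utility above can be exercised as follows; the timed action is a placeholder assumption.

package tutor.utils

object BenchmarkUtilDemo extends App {
  // record wraps a by-name action, logs begin/end plus the elapsed time, and returns the result
  val sum = BenchmarkUtil.record("sum the first million integers") {
    (1L to 1000000L).sum
  }
  println(s"sum = $sum")

  // recordStart/recordElapse measure across two separate points in the program
  val begin = BenchmarkUtil.recordStart("manual timing")
  Thread.sleep(100)
  BenchmarkUtil.recordElapse("manual timing", begin)
}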
Example 145
Source File: MainApp.scala    From CodeAnalyzerTutorial   with Apache License 2.0 5 votes vote down vote up
package tutor

import java.io.File

import com.typesafe.scalalogging.StrictLogging
import tutor.PresetFilters.{ignoreFolders, knownFileTypes}
import tutor.repo.{AnalyzeHistoryRepository, H2DB}
import tutor.utils.FileUtil.Path
import tutor.utils.{BenchmarkUtil, WriteSupport}

object MainApp extends App with ReportFormatter with WriteSupport with StrictLogging {
  if (args.length < 1) {
    println("usage: CodeAnalyzer FilePath [-oOutputfile]")
  } else {
    val path: Path = args(0)
    val file = new File(path)
    val analyzer = args.find(_.startsWith("-p")).map { _ =>
      logger.info("using par collection mode")
      new CodebaseAnalyzerParImpl with DirectoryScanner with SourceCodeAnalyzer with AnalyzeHistoryRepository with H2DB
    }.getOrElse {
      logger.info("using sequence collection mode")
      new CodebaseAnalyzerSeqImpl with DirectoryScanner with SourceCodeAnalyzer with AnalyzeHistoryRepository with H2DB
    }
    val rs = if (file.isFile) {
      analyzer.processFile(file.getAbsolutePath).map(format).getOrElse(s"error processing $path")
    } else {
      BenchmarkUtil.record(s"analyze code under $path") {
        analyzer.analyze(path, knownFileTypes, ignoreFolders).map(format).getOrElse("no result found")
      }
    }
    args.find(_.startsWith("-o")).foreach { opt =>
      val output = opt.drop(2)
      withWriter(output) {
        _.write(rs)
      }
      println(s"report saved into $output")
    }
    println(rs)
  }

} 
Example 146
Source File: CodebaseAnalyzerStreamApp.scala    From CodeAnalyzerTutorial   with Apache License 2.0 5 votes vote down vote up
package tutor

import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.scaladsl._
import com.typesafe.scalalogging.StrictLogging
import tutor.utils.BenchmarkUtil

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future
import scala.util.{Failure, Success}

object CodebaseAnalyzerStreamApp extends App with DirectoryScanner with SourceCodeAnalyzer with ReportFormatter with StrictLogging {

  implicit val system = ActorSystem("CodebaseAnalyzer")
  implicit val materializer = ActorMaterializer()
  implicit val ec = system.dispatcher

  val path = args(0)
  val beginTime = BenchmarkUtil.recordStart(s"analyze $path with akka stream")
  val files = scan(path, PresetFilters.knownFileTypes, PresetFilters.ignoreFolders).iterator
  var errorProcessingFiles: ArrayBuffer[Throwable] = ArrayBuffer.empty

  val done = Source.fromIterator(() => files).mapAsync(8)(path => Future {
    processFile(path)
  }).fold(CodebaseInfo.empty) {
    (acc, trySourceCodeInfo) =>
      trySourceCodeInfo match {
        case Success(sourceCodeInfo) => acc + sourceCodeInfo
        case Failure(e) => {
          errorProcessingFiles += e
          acc
        }
      }
  }.runForeach(codebaseInfo => {
    println(format(codebaseInfo))
    println(s"there are ${errorProcessingFiles.size} files failed to process.")
  })
  done.onComplete { _ =>
    BenchmarkUtil.recordElapse(s"analyze $path with akka stream", beginTime)
    system.terminate()
  }
} 
Example 147
Source File: AkkaConfigPropertySourceAdapterSpec.scala    From akka-spring-boot   with Apache License 2.0 5 votes vote down vote up
package com.github.scalaspring.akka.config

import java.util

import com.typesafe.config.ConfigFactory
import com.typesafe.scalalogging.StrictLogging
import org.scalatest.{FlatSpec, Matchers}

import scala.collection.JavaConverters._

class AkkaConfigPropertySourceAdapterSpec extends FlatSpec with Matchers with StrictLogging {

  val goodProperties = textToProperties(
    """|list[0]=zero
       |list[1]=one
       |list[2]=two
       |normal=normal""".stripMargin
  )

  val badProperties = textToProperties(
    """|list[0]=zero
       |list[1]=one
       |list[2]=two
       |list=bad""".stripMargin
  )

  def textToProperties(text: String): java.util.Map[String, String] = {
    text.lines.map { line =>
      line.split('=') match {
        case Array(k, v) => (k, v)
        case _ => sys.error(s"invalid property format $line")
      }
    }.foldLeft(new java.util.LinkedHashMap[String, String]())((m, t) => { m.put(t._1, t._2); m })
  }

  def validateListProperty(list: java.util.List[String]): Unit = {
    list should have size 3
    list.get(0) shouldBe "zero"
    list.get(1) shouldBe "one"
    list.get(2) shouldBe "two"
  }
  
  "Indexed properties" should "be converted to a list" in {

    val converted: java.util.Map[String, AnyRef] = AkkaConfigPropertySourceAdapter.convertIndexedProperties(goodProperties)
    val list = converted.get("list").asInstanceOf[java.util.List[String]]

    converted.keySet should have size 2
    converted.get("normal") shouldBe "normal"

    validateListProperty(list)
  }

  "Overlapping (bad) property hierarchy" should "throw exception" in {

    an [IllegalArgumentException] should be thrownBy {
      AkkaConfigPropertySourceAdapter.convertIndexedProperties(badProperties)
    }

    // Exception should be thrown regardless of property order
    val reversed = new util.LinkedHashMap[String, String]()
    badProperties.entrySet().asScala.foreach { e => reversed.put(e.getKey, e.getValue) }

    an [IllegalArgumentException] should be thrownBy {
      AkkaConfigPropertySourceAdapter.convertIndexedProperties(reversed)
    }
  }

  "Akka Config" should "parse converted property map" in {
    val converted = AkkaConfigPropertySourceAdapter.convertIndexedProperties(goodProperties)
    val config = ConfigFactory.parseMap(converted)

    config.entrySet should have size 2
    config.hasPath("list") shouldBe true
    config.hasPath("normal") shouldBe true

    validateListProperty(config.getStringList("list"))
  }
} 
Example 148
Source File: AkkaConfigPropertySourceAdapterPatternSpec.scala    From akka-spring-boot   with Apache License 2.0 5 votes vote down vote up
package com.github.scalaspring.akka.config

import java.util.regex.Matcher

import com.typesafe.scalalogging.StrictLogging
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.{FlatSpec, Matchers}

class AkkaConfigPropertySourceAdapterPatternSpec extends FlatSpec with Matchers with StrictLogging {

  val indexed = Table(
    ("name",                      "path",                 "index"),
    ("x[0]",                      "x",                    0),
    ("someProperty[0]",           "someProperty",         0),
    ("some_property[1]",          "some_property",        1),
    ("some.property[0]",          "some.property",        0),
    (" some.property[0] ",        "some.property",        0),
    ("some.other.property[893]",  "some.other.property",  893)
  )

  val nonIndexed = Table(
    ("name"),
    ("x"),
    ("someProperty"),
    ("some_property"),
    ("some.property"),
    ("some.other.property")
  )

  "Indexed property regular expression" should "match indexed property names" in {
    forAll (indexed) { (name: String, path: String, index: Int) =>
      val m: Matcher = AkkaConfigPropertySourceAdapter.INDEXED_PROPERTY_PATTERN.matcher(name)
      m.matches() shouldBe true
      m.group("path") shouldEqual path
      m.group("index") shouldEqual index.toString
    }
  }

  it should "not match non-indexed property names" in {
    forAll (nonIndexed) { (name: String) =>
      val m: Matcher = AkkaConfigPropertySourceAdapter.INDEXED_PROPERTY_PATTERN.matcher(name)
      m.matches() shouldBe false
    }
  }

} 
Example 149
Source File: ResourceInjectionTest.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.router2

import javax.naming.ConfigurationException

import com.google.inject.Injector
import com.google.inject.ProvisionException
import com.typesafe.scalalogging.StrictLogging
import org.junit.Test

trait ResourceInjectionTest extends StrictLogging {
  def injector: Injector

  @Test
  def routerInjection(): Unit = {
    injector.getProvider(classOf[NaptimePlayRouter])
  }

  @Test
  def resourceInjection(): Unit = {
    val naptimePlayRouter = try {
      Some(injector.getInstance(classOf[NaptimePlayRouter]))
    } catch {
      case e: ConfigurationException =>
        logger.warn(s"No instance of 'NaptimePlayRouter' bound. Skipping router2 tests.", e)
        None
      case e: ProvisionException =>
        logger.error("Encountered an exception provisioning 'NaptimePlayRouter'.", e)
        None
    }

    for {
      router <- naptimePlayRouter
      resource <- router.naptimeRoutes.routerBuilders
    } {
      injector.getProvider(resource.resourceClass())
      logger.debug(s"Resource ${resource.resourceClass().getName} is injectable.")
    }
  }
} 
Example 150
Source File: SangriaGraphQlSchemaBuilder.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.ari.graphql

import com.linkedin.data.DataMap
import com.linkedin.data.schema.RecordDataSchema
import com.typesafe.scalalogging.StrictLogging
import org.coursera.naptime.ari.graphql.schema.NaptimeTopLevelResourceField
import org.coursera.naptime.ari.graphql.schema.SchemaErrors
import org.coursera.naptime.ari.graphql.schema.SchemaMetadata
import org.coursera.naptime.ari.graphql.schema.WithSchemaErrors
import org.coursera.naptime.schema.Resource
import sangria.schema.Context
import sangria.schema.Schema
import sangria.schema.Value
import sangria.schema.Field
import sangria.schema.ObjectType

class SangriaGraphQlSchemaBuilder(resources: Set[Resource], schemas: Map[String, RecordDataSchema])
    extends StrictLogging {

  val schemaMetadata = SchemaMetadata(resources, schemas)

  
  def generateSchema(): WithSchemaErrors[Schema[SangriaGraphQlContext, DataMap]] = {
    val topLevelResourceObjectsAndErrors = resources.map { resource =>
      val lookupTypeAndErrors =
        NaptimeTopLevelResourceField.generateLookupTypeForResource(resource, schemaMetadata)
      val fields = lookupTypeAndErrors.data.flatMap { resourceObject =>
        if (resourceObject.fields.nonEmpty) {
          Some(
            Field.apply[SangriaGraphQlContext, DataMap, DataMap, Any](
              NaptimeTopLevelResourceField.formatResourceTopLevelName(resource),
              resourceObject,
              resolve = (_: Context[SangriaGraphQlContext, DataMap]) => {
                Value(new DataMap())
              }))
        } else {
          None
        }
      }
      lookupTypeAndErrors.copy(data = fields)
    }

    val topLevelResourceObjects = topLevelResourceObjectsAndErrors.flatMap(_.data)
    val schemaErrors = topLevelResourceObjectsAndErrors.foldLeft(SchemaErrors.empty)(_ ++ _.errors)

    val dedupedResources = topLevelResourceObjects.groupBy(_.name).map(_._2.head).toList
    val rootObject = ObjectType[SangriaGraphQlContext, DataMap](
      name = "root",
      description = "Top-level accessor for Naptime resources",
      fields = dedupedResources)

    WithSchemaErrors(Schema(rootObject), schemaErrors)
  }
} 
Example 151
Source File: MetricsCollectionMiddleware.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.ari.graphql.controllers.middleware

import com.typesafe.scalalogging.StrictLogging
import org.coursera.naptime.ari.graphql.SangriaGraphQlContext
import sangria.execution.BeforeFieldResult
import sangria.execution.MiddlewareErrorField
import sangria.execution.MiddlewareQueryContext
import sangria.schema.Context

class MetricsCollectionMiddleware(metricsCollector: GraphQLMetricsCollector)
    extends MiddlewareErrorField[SangriaGraphQlContext] {

  override type QueryVal = Unit
  override type FieldVal = Unit

  override def beforeQuery(context: MiddlewareQueryContext[SangriaGraphQlContext, _, _]): Unit = ()

  override def afterQuery(
      queryVal: Unit,
      context: MiddlewareQueryContext[SangriaGraphQlContext, _, _]): Unit = ()

  override def beforeField(
      queryVal: Unit,
      mctx: MiddlewareQueryContext[SangriaGraphQlContext, _, _],
      ctx: Context[SangriaGraphQlContext, _]): BeforeFieldResult[SangriaGraphQlContext, FieldVal] =
    BeforeFieldResult(Unit, None)

  override def fieldError(
      queryVal: QueryVal,
      fieldVal: FieldVal,
      error: Throwable,
      mctx: MiddlewareQueryContext[SangriaGraphQlContext, _, _],
      ctx: Context[SangriaGraphQlContext, _]): Unit = {
    val fieldAndParentName = ctx.parentType.name + ":" + ctx.field.name
    metricsCollector.markFieldError(fieldAndParentName)
  }
}

trait GraphQLMetricsCollector {
  def markFieldError(fieldName: String): Unit
  def timeQueryParsing[A](operationName: String)(f: => A): A
}

class LoggingMetricsCollector extends GraphQLMetricsCollector with StrictLogging {
  def markFieldError(fieldName: String): Unit = {
    logger.info(s"Error when loading field $fieldName")
  }

  def timeQueryParsing[A](operationName: String)(f: => A): A = {
    val before = System.currentTimeMillis()
    val res = f
    val after = System.currentTimeMillis()
    logger.info(s"Parsed query $operationName in ${after - before}ms")
    res
  }
}

class NoopMetricsCollector extends GraphQLMetricsCollector with StrictLogging {
  def markFieldError(fieldName: String): Unit = ()

  def timeQueryParsing[A](operationName: String)(f: => A): A = {
    f
  }
} 
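
A hedged sketch of wiring the pieces above together; the parse step is a stand-in, not the project's actual query pipeline.

package org.coursera.naptime.ari.graphql.controllers.middleware

object MetricsWiringSketch {
  val collector: GraphQLMetricsCollector = new LoggingMetricsCollector

  // the middleware would typically be passed to Sangria's Executor via its middleware list
  val middleware = new MetricsCollectionMiddleware(collector)

  // timeQueryParsing wraps any by-name block and logs how long it took
  def parse(query: String): String =
    collector.timeQueryParsing("ExampleOperation") {
      query.trim // placeholder for real query parsing
    }
}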
Example 152
Source File: QueryComplexityFilter.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.ari.graphql.controllers.filters

import javax.inject.Inject
import javax.inject.Singleton

import com.typesafe.scalalogging.StrictLogging
import org.coursera.naptime.ari.Response
import org.coursera.naptime.ari.graphql.GraphqlSchemaProvider
import org.coursera.naptime.ari.graphql.SangriaGraphQlContext
import org.coursera.naptime.ari.graphql.controllers.GraphQLController
import org.coursera.naptime.ari.graphql.marshaller.NaptimeMarshaller._
import org.coursera.naptime.ari.graphql.resolvers.NaptimeResolver
import org.coursera.naptime.ari.graphql.resolvers.NoopResolver
import play.api.libs.json.JsObject
import play.api.libs.json.Json
import play.api.mvc.Results
import sangria.ast.Document
import sangria.execution.ErrorWithResolver
import sangria.execution.Executor
import sangria.execution.QueryAnalysisError
import sangria.execution.QueryReducer

import scala.concurrent.ExecutionContext
import scala.concurrent.Future

@Singleton
class QueryComplexityFilter @Inject()(
    graphqlSchemaProvider: GraphqlSchemaProvider,
    configuration: ComplexityFilterConfiguration)(implicit executionContext: ExecutionContext)
    extends Filter
    with Results
    with StrictLogging {

  val MAX_COMPLEXITY = configuration.maxComplexity

  def apply(nextFilter: FilterFn): FilterFn = { incoming =>
    computeComplexity(incoming.document, incoming.variables)
      .flatMap { complexity =>
        if (complexity > MAX_COMPLEXITY) {
          Future.successful(
            OutgoingQuery(
              response = Json.obj("error" -> "Query is too complex.", "complexity" -> complexity),
              ariResponse = None))
        } else {
          nextFilter.apply(incoming)
        }
      }
      .recover {
        case error: QueryAnalysisError =>
          OutgoingQuery(error.resolveError.as[JsObject], None)
        case error: ErrorWithResolver =>
          OutgoingQuery(error.resolveError.as[JsObject], None)
        case error: Exception =>
          OutgoingQuery(Json.obj("errors" -> Json.arr(error.getMessage)), None)
      }
  }

  private[graphql] def computeComplexity(queryAst: Document, variables: JsObject)(
      implicit executionContext: ExecutionContext): Future[Double] = {
    // TODO(bryan): is there a way around this var?
    var complexity = 0D
    val complReducer = QueryReducer.measureComplexity[SangriaGraphQlContext] { (c, ctx) =>
      complexity = c
      ctx
    }
    val executorFut = Executor.execute(
      graphqlSchemaProvider.schema,
      queryAst,
      SangriaGraphQlContext(null, null, executionContext, debugMode = false),
      variables = variables,
      exceptionHandler = GraphQLController.exceptionHandler(logger),
      queryReducers = List(complReducer),
      deferredResolver = new NoopResolver())

    executorFut.map { _ =>
      complexity
    }

  }
}

case class ComplexityFilterConfiguration(maxComplexity: Int)

object ComplexityFilterConfiguration {
  val DEFAULT = ComplexityFilterConfiguration(100000)
} 
Example 153
Source File: NaptimeRecordField.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.ari.graphql.schema

import com.linkedin.data.DataMap
import com.linkedin.data.schema.RecordDataSchema
import com.typesafe.scalalogging.StrictLogging
import org.coursera.naptime.ResourceName
import org.coursera.naptime.ari.graphql.SangriaGraphQlContext
import sangria.schema.Field
import sangria.schema.ObjectType
import sangria.schema.StringType
import sangria.schema.Value

import scala.collection.JavaConverters._

object NaptimeRecordField extends StrictLogging {

  private[schema] def build(
      schemaMetadata: SchemaMetadata,
      recordDataSchema: RecordDataSchema,
      fieldName: String,
      namespace: Option[String],
      resourceName: ResourceName,
      currentPath: List[String]) = {

    Field.apply[SangriaGraphQlContext, DataMapWithParent, Any, Any](
      name = FieldBuilder.formatName(fieldName),
      fieldType = getType(
        schemaMetadata,
        recordDataSchema,
        namespace,
        resourceName,
        currentPath :+ fieldName),
      resolve = context => {
        context.value.element.get(fieldName) match {
          case dataMap: DataMap =>
            context.value.copy(element = dataMap)
          case other: Any =>
            logger.warn(s"Expected DataMap but got $other")
            Value(null)
          case null =>
            Value(null)
        }

      })
  }

  private[schema] def getType(
      schemaMetadata: SchemaMetadata,
      recordDataSchema: RecordDataSchema,
      namespace: Option[String],
      resourceName: ResourceName,
      currentPath: List[String]): ObjectType[SangriaGraphQlContext, DataMapWithParent] = {

    val formattedResourceName = NaptimeResourceUtils.formatResourceName(resourceName)
    ObjectType[SangriaGraphQlContext, DataMapWithParent](
      FieldBuilder.formatName(s"${formattedResourceName}_${recordDataSchema.getFullName}"),
      recordDataSchema.getDoc,
      fieldsFn = () => {
        val fields = recordDataSchema.getFields.asScala.map { field =>
          FieldBuilder.buildField(
            schemaMetadata,
            field,
            namespace,
            resourceName = resourceName,
            currentPath = currentPath)
        }.toList
        if (fields.isEmpty) {
          // TODO(bryan): Handle this case better
          EMPTY_FIELDS_FALLBACK
        } else {
          fields
        }
      })
  }

  val EMPTY_FIELDS_FALLBACK = List(
    Field.apply[SangriaGraphQlContext, DataMapWithParent, Any, Any](
      "ArbitraryField",
      StringType,
      resolve = context => null))

} 
Example 154
Source File: NaptimePaginationField.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.ari.graphql.schema

import com.typesafe.scalalogging.StrictLogging
import org.coursera.naptime.PaginationConfiguration
import org.coursera.naptime.ResourceName
import org.coursera.naptime.ResponsePagination
import org.coursera.naptime.ari.graphql.SangriaGraphQlContext
import sangria.schema.Argument
import sangria.schema.Field
import sangria.schema.IntType
import sangria.schema.LongType
import sangria.schema.ObjectType
import sangria.schema.OptionInputType
import sangria.schema.OptionType
import sangria.schema.StringType

object NaptimePaginationField extends StrictLogging {

  def getField(
      resourceName: ResourceName,
      fieldName: String): ObjectType[SangriaGraphQlContext, ResponsePagination] = {

    ObjectType[SangriaGraphQlContext, ResponsePagination](
      name = "ResponsePagination",
      fields = List(
        Field.apply[SangriaGraphQlContext, ResponsePagination, Any, Any](
          name = "next",
          fieldType = OptionType(StringType),
          resolve = _.value.next),
        Field.apply[SangriaGraphQlContext, ResponsePagination, Any, Any](
          name = "total",
          fieldType = OptionType(LongType),
          resolve = context => {
            context.value match {
              case responsePagination: ResponsePagination =>
                responsePagination.total
              case null =>
                logger.error("Expected ResponsePagination but got null")
                None
            }
          })))
  }

  private[graphql] val limitArgument = Argument(
    name = "limit",
    argumentType = OptionInputType(IntType),
    defaultValue = PaginationConfiguration().defaultLimit,
    description = "Maximum number of results to include in response")

  private[graphql] val startArgument = Argument(
    name = "start",
    argumentType = OptionInputType(StringType),
    description = "Cursor to start pagination at")

  val paginationArguments = List(limitArgument, startArgument)

} 
Example 155
Source File: LocalSchemaProvider.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.ari

import javax.inject.Inject

import com.linkedin.data.schema.DataSchema
import com.linkedin.data.schema.RecordDataSchema
import com.typesafe.scalalogging.StrictLogging
import org.coursera.naptime.ResourceName
import org.coursera.naptime.router2.NaptimeRoutes
import org.coursera.naptime.schema.Resource


class LocalSchemaProvider @Inject()(naptimeRoutes: NaptimeRoutes)
    extends SchemaProvider
    with StrictLogging {

  private[this] val resourceSchemaMap: Map[ResourceName, Resource] =
    naptimeRoutes.schemaMap.flatMap {
      // TODO: handle sub resources
      case (_, schema)
          if schema.parentClass.isEmpty ||
            schema.parentClass.contains("org.coursera.naptime.resources.RootResource") =>
        val resourceName = ResourceName(schema.name, version = schema.version.getOrElse(0L).toInt)
        Some(resourceName -> schema)

      case (_, schema) =>
        logger.warn(s"Cannot handle nested resource $schema")
        None
    }

  private[this] val mergedTypes = naptimeRoutes.routerBuilders
    .flatMap(_.types.map(_.tuple))
    .filter(_._2.isInstanceOf[RecordDataSchema])
    .map(tuple => tuple._1 -> tuple._2.asInstanceOf[RecordDataSchema])
    .toMap

  override val fullSchema: FullSchema =
    FullSchema(
      naptimeRoutes.schemaMap.values.toSet,
      naptimeRoutes.routerBuilders.flatMap(_.types.map(_.value)))

  override def mergedType(resourceName: ResourceName): Option[RecordDataSchema] = {
    resourceSchemaMap.get(resourceName).flatMap { schema =>
      mergedTypes.get(schema.mergedType)
    }
  }
} 
Example 156
Source File: Authenticator.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.access.authenticator

import com.typesafe.scalalogging.StrictLogging
import org.coursera.common.concurrent.Futures
import org.coursera.naptime.NaptimeActionException
import org.coursera.naptime.access.authenticator.combiner.And
import org.coursera.naptime.access.authenticator.combiner.AnyOf
import org.coursera.naptime.access.authenticator.combiner.FirstOf
import play.api.http.Status.FORBIDDEN
import play.api.http.Status.UNAUTHORIZED
import play.api.mvc.RequestHeader

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.util.control.NonFatal


  def maybeAuthenticate(requestHeader: RequestHeader)(
      implicit ec: ExecutionContext): Future[Option[Either[NaptimeActionException, A]]]

  def collect[B](f: PartialFunction[A, B]): Authenticator[B] = {
    val self = this
    new Authenticator[B] {
      override def maybeAuthenticate(requestHeader: RequestHeader)(
          implicit ec: ExecutionContext): Future[Option[Either[NaptimeActionException, B]]] = {

        Futures
          .safelyCall(self.maybeAuthenticate(requestHeader))
          .map(_.map(_.right.map(f.lift)))
          .map {
            case Some(Right(None))    => None
            case Some(Right(Some(b))) => Some(Right(b))
            case Some(Left(error))    => Some(Left(error))
            case None                 => None
          }
          .recover(Authenticator.errorRecovery)
      }
    }
  }

  def map[B](f: A => B): Authenticator[B] = collect(PartialFunction(f))

}

object Authenticator extends StrictLogging with AnyOf with FirstOf with And {

  def apply[P, A](
      parser: HeaderAuthenticationParser[P],
      decorator: Decorator[P, A]): Authenticator[A] = {

    new Authenticator[A] {
      def maybeAuthenticate(requestHeader: RequestHeader)(
          implicit ec: ExecutionContext): Future[Option[Either[NaptimeActionException, A]]] = {

        parser.parseHeader(requestHeader) match {
          case ParseResult.Success(parsed) =>
            Futures
              .safelyCall(decorator(parsed))
              .map { either =>
                either.left
                  .map { message =>
                    Some(Left(NaptimeActionException(FORBIDDEN, None, Some(message))))
                  }
                  .right
                  .map { decorated =>
                    Some(Right(decorated))
                  }
                  .merge
              }
              .recover(errorRecovery)
          case ParseResult.Error(message, status) =>
            Future.successful(
              Some(Left(NaptimeActionException(status, Some("auth.parse"), Some(message)))))
          case ParseResult.Skip => Future.successful(None)
        }
      }
    }

  }

  private[access] def authenticateAndRecover[A](
      authenticator: Authenticator[A],
      requestHeader: RequestHeader)(
      implicit ec: ExecutionContext): Future[Option[Either[NaptimeActionException, A]]] = {
    Futures
      .safelyCall(authenticator.maybeAuthenticate(requestHeader))
      .recover(errorRecovery)
  }

  def errorRecovery[A]: PartialFunction[Throwable, Option[Either[NaptimeActionException, A]]] = {
    case NonFatal(e) =>
      logger.error("Unexpected authentication error", e)
      val message = s"Unexpected authentication error: ${e.getMessage}"
      Some(Left(NaptimeActionException(UNAUTHORIZED, Some("auth.perms"), Some(message))))
  }

} 
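
A brief sketch of the collect combinator defined above; the User type and the pre-built authenticator parameter are assumptions for illustration.

package org.coursera.naptime.access.authenticator

object AuthenticatorSketch {
  case class User(id: String, isAdmin: Boolean)

  // narrow an existing Authenticator[User] so that only admins authenticate;
  // non-admins fall through to "not authenticated" rather than producing an error
  def adminOnly(userAuthenticator: Authenticator[User]): Authenticator[User] =
    userAuthenticator.collect {
      case user if user.isAdmin => user
    }
}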
Example 157
Source File: CourierQueryParsers.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.router2

import java.io.IOException

import com.linkedin.data.DataMap
import com.linkedin.data.schema.DataSchema
import com.linkedin.data.schema.validation.CoercionMode
import com.linkedin.data.schema.validation.RequiredMode
import com.linkedin.data.schema.validation.ValidateDataAgainstSchema
import com.linkedin.data.schema.validation.ValidationOptions
import com.typesafe.scalalogging.StrictLogging
import org.coursera.courier.codecs.InlineStringCodec
import org.coursera.naptime.courier.StringKeyCodec
import play.api.mvc.RequestHeader

object CourierQueryParsers extends StrictLogging {

  import CollectionResourceRouter.errorRoute

  private[this] val validationOptions =
    new ValidationOptions(RequiredMode.FIXUP_ABSENT_WITH_DEFAULT, CoercionMode.STRING_TO_PRIMITIVE)

  private[this] def parseStringToDataMap(
      paramName: String,
      schema: DataSchema,
      resourceClass: Class[_])(value: String): Either[RouteAction, DataMap] = {
    try {
      val parsed = if (value.startsWith("(") && value.endsWith(")")) {
        InlineStringCodec.instance.bytesToMap(value.getBytes("UTF-8"))
      } else {
        val codec = new StringKeyCodec(schema)
        codec.bytesToMap(value.getBytes("UTF-8"))
      }
      val validated = ValidateDataAgainstSchema.validate(parsed, schema, validationOptions)
      if (validated.isValid) {
        Right(validated.getFixed.asInstanceOf[DataMap])
      } else {
        logger.warn(
          s"${resourceClass.getName}: Bad query parameter for parameter " +
            s"'$paramName': $value. Errors: ${validated.getMessages}")
        Left(errorRoute(s"Improperly formatted value for parameter '$paramName'", resourceClass))
      }
    } catch {
      case ioException: IOException =>
        logger.warn(
          s"${resourceClass.getName}: Bad query parameter for parameter " +
            s"'$paramName': $value. Errors: ${ioException.getMessage}")
        Left(errorRoute(s"Improperly formatted value for parameter '$paramName'", resourceClass))
    }
  }

  def strictParse(
      paramName: String,
      schema: DataSchema,
      resourceClass: Class[_],
      rh: RequestHeader): Either[RouteAction, DataMap] = {
    val queryStringResults = rh.queryString.get(paramName)
    if (queryStringResults.isEmpty || queryStringResults.get.isEmpty) {
      Left(errorRoute(s"Missing required parameter '$paramName'", resourceClass))
    } else if (queryStringResults.get.tail.isEmpty) {
      val stringValue = queryStringResults.get.head
      parseStringToDataMap(paramName, schema, resourceClass)(stringValue)
    } else {
      Left(errorRoute(s"Too many query parameters for '$paramName", resourceClass))
    }
  }

  def optParse(
      paramName: String,
      schema: DataSchema,
      resourceClass: Class[_],
      rh: RequestHeader): Either[RouteAction, Option[DataMap]] = {
    val queryStringResults = rh.queryString.get(paramName)
    if (queryStringResults.isEmpty || queryStringResults.get.isEmpty) {
      Right(None)
    } else if (queryStringResults.get.tail.isEmpty) {
      val stringValue = queryStringResults.get.head
      parseStringToDataMap(paramName, schema, resourceClass)(stringValue).right.map(Some(_))
    } else {
      Left(errorRoute(s"Too many query parameters for '$paramName", resourceClass))
    }
  }

  // TODO: Add a 'QTry' query parameter type that will attempt to parse the query parameter but
  // instead of failing, will provide the valiation errors to the resource handler to do with what
  // they want.
} 
Example 158
Source File: AttributesProvider.scala    From naptime   with Apache License 2.0 5 votes vote down vote up
package org.coursera.naptime.router2

import com.typesafe.scalalogging.StrictLogging
import org.coursera.courier.templates.DataTemplates.DataConversion
import org.coursera.naptime.courier.CourierFormats
import org.coursera.naptime.schema.Attribute
import org.coursera.naptime.schema.JsValue
import play.api.libs.json.JsError
import play.api.libs.json.JsObject
import play.api.libs.json.JsSuccess
import play.api.libs.json.Json

import scala.util.control.NonFatal

object AttributesProvider extends StrictLogging {

  val SCALADOC_ATTRIBUTE_NAME = "scaladocs"

  lazy val scaladocs: Map[String, JsObject] = {
    val scaladocPath = "/naptime.scaladoc.json"
    (for {
      stream <- Option(getClass.getResourceAsStream(scaladocPath))
      json <- try {
        Some(Json.parse(stream))
      } catch {
        case NonFatal(exception) =>
          logger.warn(
            s"Could not parse contents of file " +
              s"$scaladocPath as JSON")
          None
      } finally {
        stream.close()
      }
      scaladocCollection <- json.validate[Map[String, JsObject]] match {
        case JsSuccess(deserialized, _) =>
          Some(deserialized)
        case JsError(_) =>
          logger.warn(
            s"Could not deserialize contents of file " +
              s"$scaladocPath as `Map[String, JsObject]`")
          None
      }
    } yield {
      scaladocCollection
    }).getOrElse(Map.empty)
  }

  def getResourceAttributes(className: String): Seq[Attribute] = {
    scaladocs
      .get(className)
      .map(value => Attribute(SCALADOC_ATTRIBUTE_NAME, Some(jsObjToJsValue(value))))
      .toList
  }

  def getMethodAttributes(className: String, methodName: String): Seq[Attribute] = {
    scaladocs
      .get(s"$className.$methodName")
      .map(value => Attribute(SCALADOC_ATTRIBUTE_NAME, Some(jsObjToJsValue(value))))
      .toList
  }

  private[this] def jsObjToJsValue(jsObj: JsObject): JsValue = {
    JsValue.build(CourierFormats.objToDataMap(jsObj), DataConversion.SetReadOnly)
  }
} 
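
Using the provider above is a matter of passing fully-qualified names; the class and method names here are hypothetical.

package org.coursera.naptime.router2

object AttributesProviderSketch {
  // each call returns a single "scaladocs" Attribute when /naptime.scaladoc.json on the
  // classpath has a matching entry, and an empty Seq otherwise
  val resourceAttributes = AttributesProvider.getResourceAttributes("org.example.CoursesResource")
  val methodAttributes   = AttributesProvider.getMethodAttributes("org.example.CoursesResource", "get")
}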
Example 159
Source File: SqlStatement.scala    From gatling-sql   with Apache License 2.0 5 votes vote down vote up
package io.github.gatling.sql

import java.sql.{Connection, PreparedStatement}

import com.typesafe.scalalogging.StrictLogging
import io.github.gatling.sql.db.ConnectionPool
import io.gatling.commons.validation.Validation
import io.gatling.core.session.{Expression, Session}
import io.gatling.commons.validation._

trait SqlStatement extends StrictLogging {

  def apply(session:Session): Validation[PreparedStatement]

  def connection = ConnectionPool.connection
}

case class SimpleSqlStatement(statement: Expression[String]) extends SqlStatement {
  def apply(session: Session): Validation[PreparedStatement] = statement(session).flatMap { stmt =>
      logger.debug(s"STMT: ${stmt}")
      connection.prepareStatement(stmt).success
    }
} 
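
A minimal sketch of building a statement with the types above; the constant expression is an assumption (in a real simulation Gatling would typically supply the Expression, e.g. from its EL).

package io.github.gatling.sql

import io.gatling.commons.validation._
import io.gatling.core.session.Session

object SqlStatementSketch {
  // a constant Expression[String]: every session resolves to the same SQL text
  val selectOne: SqlStatement = SimpleSqlStatement((_: Session) => "SELECT 1".success)
}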
Example 160
Source File: SetSessionScala.scala    From akka-http-session   with Apache License 2.0 5 votes vote down vote up
package com.softwaremill.example.session

import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.server.Directives._
import akka.stream.ActorMaterializer
import com.softwaremill.session.CsrfDirectives._
import com.softwaremill.session.CsrfOptions._
import com.softwaremill.session.SessionDirectives._
import com.softwaremill.session.SessionOptions._
import com.softwaremill.session._
import com.typesafe.scalalogging.StrictLogging

import scala.io.StdIn

object SetSessionScala extends App with StrictLogging {
  implicit val system = ActorSystem("example")
  implicit val materializer = ActorMaterializer()

  import system.dispatcher

  val sessionConfig = SessionConfig.default(
    "c05ll3lesrinf39t7mc5h6un6r0c69lgfno69dsak3vabeqamouq4328cuaekros401ajdpkh60rrtpd8ro24rbuqmgtnd1ebag6ljnb65i8a55d482ok7o0nch0bfbe")
  implicit val sessionManager = new SessionManager[MyScalaSession](sessionConfig)
  implicit val refreshTokenStorage = new InMemoryRefreshTokenStorage[MyScalaSession] {
    def log(msg: String) = logger.info(msg)
  }

  def mySetSession(v: MyScalaSession) = setSession(refreshable, usingCookies, v)

  val myRequiredSession = requiredSession(refreshable, usingCookies)
  val myInvalidateSession = invalidateSession(refreshable, usingCookies)

  val routes =
    randomTokenCsrfProtection(checkHeader) {
      pathPrefix("api") {
        path("do_login") {
          post {
            entity(as[String]) { body =>
              logger.info(s"Logging in $body")
              mySetSession(MyScalaSession(body)) {
                setNewCsrfToken(checkHeader) { ctx =>
                  ctx.complete("ok")
                }
              }
            }
          }
        }
      }
    }

  val bindingFuture = Http().bindAndHandle(routes, "localhost", 8080)

  println("Server started, press enter to stop. Visit http://localhost:8080 to see the demo.")
  StdIn.readLine()

  import system.dispatcher

  bindingFuture
    .flatMap(_.unbind())
    .onComplete { _ =>
      system.terminate()
      println("Server stopped")
    }
} 
Example 161
Source File: SessionInvalidationScala.scala    From akka-http-session   with Apache License 2.0 5 votes vote down vote up
package com.softwaremill.example.session

import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.server.Directives._
import akka.stream.ActorMaterializer
import com.softwaremill.session.SessionDirectives._
import com.softwaremill.session.SessionOptions._
import com.softwaremill.session._
import com.typesafe.scalalogging.StrictLogging

import scala.io.StdIn

object SessionInvalidationScala extends App with StrictLogging {
  implicit val system = ActorSystem("example")
  implicit val materializer = ActorMaterializer()

  import system.dispatcher

  val sessionConfig = SessionConfig.default(
    "c05ll3lesrinf39t7mc5h6un6r0c69lgfno69dsak3vabeqamouq4328cuaekros401ajdpkh60rrtpd8ro24rbuqmgtnd1ebag6ljnb65i8a55d482ok7o0nch0bfbe")
  implicit val sessionManager = new SessionManager[MyScalaSession](sessionConfig)
  implicit val refreshTokenStorage = new InMemoryRefreshTokenStorage[MyScalaSession] {
    def log(msg: String) = logger.info(msg)
  }

  def mySetSession(v: MyScalaSession) = setSession(refreshable, usingCookies, v)

  val myRequiredSession = requiredSession(refreshable, usingCookies)
  val myInvalidateSession = invalidateSession(refreshable, usingCookies)

  val routes =
    path("logout") {
      post {
        myRequiredSession { session =>
          myInvalidateSession { ctx =>
            logger.info(s"Logging out $session")
            ctx.complete("ok")
          }
        }
      }
    }

  val bindingFuture = Http().bindAndHandle(routes, "localhost", 8080)

  println("Server started, press enter to stop. Visit http://localhost:8080 to see the demo.")
  StdIn.readLine()

  import system.dispatcher

  bindingFuture
    .flatMap(_.unbind())
    .onComplete { _ =>
      system.terminate()
      println("Server stopped")
    }
} 
Example 162
Source File: ScalaExample.scala    From akka-http-session   with Apache License 2.0 5 votes vote down vote up
package com.softwaremill.example

import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.server.Directives._
import akka.stream.ActorMaterializer
import com.softwaremill.example.session.MyScalaSession
import com.softwaremill.session.CsrfDirectives._
import com.softwaremill.session.CsrfOptions._
import com.softwaremill.session.SessionDirectives._
import com.softwaremill.session.SessionOptions._
import com.softwaremill.session._
import com.typesafe.scalalogging.StrictLogging

import scala.io.StdIn

object Example extends App with StrictLogging {
  implicit val system = ActorSystem("example")
  implicit val materializer = ActorMaterializer()

  import system.dispatcher

  val sessionConfig = SessionConfig.default(
    "c05ll3lesrinf39t7mc5h6un6r0c69lgfno69dsak3vabeqamouq4328cuaekros401ajdpkh60rrtpd8ro24rbuqmgtnd1ebag6ljnb65i8a55d482ok7o0nch0bfbe")
  implicit val sessionManager = new SessionManager[MyScalaSession](sessionConfig)
  implicit val refreshTokenStorage = new InMemoryRefreshTokenStorage[MyScalaSession] {
    def log(msg: String) = logger.info(msg)
  }

  def mySetSession(v: MyScalaSession) = setSession(refreshable, usingCookies, v)

  val myRequiredSession = requiredSession(refreshable, usingCookies)
  val myInvalidateSession = invalidateSession(refreshable, usingCookies)

  val routes =
    path("") {
      redirect("/site/index.html", Found)
    } ~
      randomTokenCsrfProtection(checkHeader) {
        pathPrefix("api") {
          path("do_login") {
            post {
              entity(as[String]) { body =>
                logger.info(s"Logging in $body")

                mySetSession(MyScalaSession(body)) {
                  setNewCsrfToken(checkHeader) { ctx =>
                    ctx.complete("ok")
                  }
                }
              }
            }
          } ~
            // This should be protected and accessible only when logged in
            path("do_logout") {
              post {
                myRequiredSession { session =>
                  myInvalidateSession { ctx =>
                    logger.info(s"Logging out $session")
                    ctx.complete("ok")
                  }
                }
              }
            } ~
            // This should be protected and accessible only when logged in
            path("current_login") {
              get {
                myRequiredSession { session => ctx =>
                  logger.info("Current session: " + session)
                  ctx.complete(session.username)
                }
              }
            }
        } ~
          pathPrefix("site") {
            getFromResourceDirectory("")
          }
      }

  val bindingFuture = Http().bindAndHandle(routes, "localhost", 8080)

  println("Server started, press enter to stop. Visit http://localhost:8080 to see the demo.")
  StdIn.readLine()

  import system.dispatcher

  bindingFuture
    .flatMap(_.unbind())
    .onComplete { _ =>
      system.terminate()
      println("Server stopped")
    }
} 
Example 163
Source File: SparkApplicationTester.scala    From TopNotch   with Apache License 2.0 5 votes vote down vote up
package com.bfm.topnotch

import org.scalatest.OneInstancePerTest
import org.apache.hadoop.hbase.CellUtil
import org.apache.hadoop.hbase.client.{HConnection, HTableInterface, Put}
import org.apache.spark._
import org.apache.spark.sql.SparkSession
import org.scalamock.scalatest.MockFactory
import org.scalatest.FlatSpec
import com.typesafe.scalalogging.StrictLogging

/**
 * This class handles some of the boilerplate of testing SparkApplications with HBase writers
 */
abstract class SparkApplicationTester extends FlatSpec with OneInstancePerTest with MockFactory with StrictLogging
  with SharedSparkContext {
  protected val hconn = mock[HConnection]
  lazy val spark = SparkSession
    .builder()
    .appName(getClass.getName)
    .master("local")
    .config("spark.sql.shuffle.partitions", "4")
    //setting this to false to emulate HiveQL's case insensitivity for column names
    .config("spark.sql.caseSensitive", "false")
    .getOrCreate()

  /**
   * Verify that the next HTable will receive the correct puts. Call this once per HTable that is supposed to be created and written to.
   * Note: All HBase tests for a SparkApplication object must be run sequentially in order for us to keep track of HTableInterface mocks
   * @param tests The test's expected name for the HTable and expected values for the Put objects placed in the HTable
   * @param acceptAnyPut Tells the mock to accept any put value. This is useful for tests that use the mock but do not
   *                     test what is put inside it.
   */
  def setHBaseMock(tests: HTableParams, acceptAnyPut: Boolean = false): Unit = {
    val table = mock[HTableInterface]
    inSequence {
      (hconn.getTable(_: String)).expects(tests.tableName).returning(table)
      inAnyOrder {
        for (correctPut <- tests.puts) {
          if (acceptAnyPut) {
            (table.put(_: Put)).expects(*)
          }
          else {
            (table.put(_: Put)).expects(where {
              (actualPut: Put) =>
                val actualValue = CellUtil.cloneValue(actualPut.get(correctPut.columnFamily, correctPut.columnQualifier).get(0))
                correctPut.valueTest(actualValue)
                // just return true; if there are issues, the value test will already have thrown an exception
                true
            })
          }
        }
      }
      (table.close _).expects().returns()
    }
  }

  /**
    * Make the next HTable accept any action. This is useful when a test needs an HBase table but the specific test
    * isn't exercising the HBase functionality itself.
    *
    * @param tableName the name of the table that will be accessed.
    */
  def allowAnyHBaseActions(tableName: String): Unit = {
    setHBaseMock(new HTableParams(tableName, Seq(null)), true)
  }

  /**
   * The set of parameters defining what values should be used to create the HTable
   * @param tableName The name of the table the test expects to be created
   * @param puts The list of parameters for the puts that the test expects to be placed in the table
   */
  case class HTableParams(
                           tableName: String,
                           puts: Seq[HPutParams]
                           )

  /**
   * The list of values that the test expects to be in a put.
   * @param row The name of the row to put into HBase
   * @param columnFamily The cell's column family
   * @param columnQualifier The cell's column qualifier
   * @param correctString A string representing the correct value or an error message
   * @param valueTest The method for checking whether the value in the cell is correct. This is needed because the
   *                  actual and intended values in a cell may be logically equal even if their byte-array
   *                  representations differ. This should throw an exception on failure, using a call like shouldBe.
   */
  case class HPutParams(
                         row: Array[Byte],
                         columnFamily: Array[Byte],
                         columnQualifier: Array[Byte],
                         correctString: String,
                         valueTest: Array[Byte] => Unit
                         )
} 
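
A hedged sketch of a concrete test built on the harness above; the table name, column names, and expected value are hypothetical, and the code under test is elided.

package com.bfm.topnotch

import org.apache.hadoop.hbase.util.Bytes
import org.scalatest.Matchers

class ExampleWriterTest extends SparkApplicationTester with Matchers {

  "An HBase writer" should "put the expected cell into the expected table" in {
    // declare the puts the mocked HTable should receive
    setHBaseMock(HTableParams("exampleTable", Seq(
      HPutParams(
        row = Bytes.toBytes("row1"),
        columnFamily = Bytes.toBytes("cf"),
        columnQualifier = Bytes.toBytes("qual"),
        correctString = "expected value",
        valueTest = actual => Bytes.toString(actual) shouldBe "expected value"
      )
    )))
    // the (hypothetical) code under test would run here, writing through hconn.getTable("exampleTable")
  }
}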
Example 164
Source File: GamePacketDecoder.scala    From wowchat   with GNU General Public License v3.0 5 votes vote down vote up
package wowchat.game

import java.util

import wowchat.common.{ByteUtils, Packet}
import com.typesafe.scalalogging.StrictLogging
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.ByteToMessageDecoder

class GamePacketDecoder extends ByteToMessageDecoder with GamePackets with StrictLogging {

  protected val HEADER_LENGTH = 4

  private var size = 0
  private var id = 0

  override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: util.List[AnyRef]): Unit = {
    if (in.readableBytes < HEADER_LENGTH) {
      return
    }

    val crypt = ctx.channel.attr(CRYPT).get

    if (size == 0 && id == 0) {
      // decrypt if necessary
      val tuple = if (crypt.isInit) {
        parseGameHeaderEncrypted(in, crypt)
      } else {
        parseGameHeader(in)
      }
      id = tuple._1
      size = tuple._2
    }

    if (size > in.readableBytes) {
      return
    }

    val byteBuf = in.readBytes(size)

    // decompress if necessary
    val (newId, decompressed) = decompress(id, byteBuf)

    val packet = Packet(newId, decompressed)

    logger.debug(f"RECV PACKET: $newId%04X - ${ByteUtils.toHexString(decompressed, true, false)}")

    out.add(packet)
    size = 0
    id = 0
  }

  def parseGameHeader(in: ByteBuf): (Int, Int) = {
    val size = in.readShort - 2
    val id = in.readShortLE
    (id, size)
  }

  def parseGameHeaderEncrypted(in: ByteBuf, crypt: GameHeaderCrypt): (Int, Int) = {
    val header = new Array[Byte](HEADER_LENGTH)
    in.readBytes(header)
    val decrypted = crypt.decrypt(header)
    val size = ((decrypted(0) & 0xFF) << 8 | decrypted(1) & 0xFF) - 2
    val id = (decrypted(3) & 0xFF) << 8 | decrypted(2) & 0xFF
    (id, size)
  }

  // vanilla has no compression. starts in cata/mop
  def decompress(id: Int, in: ByteBuf): (Int, ByteBuf) = {
    (id, in)
  }
} 
Example 165
Source File: GamePacketEncoder.scala    From wowchat   with GNU General Public License v3.0 5 votes vote down vote up
package wowchat.game

import wowchat.common.{ByteUtils, Packet}
import com.typesafe.scalalogging.StrictLogging
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.MessageToByteEncoder

import scala.collection.mutable.ArrayBuffer

class GamePacketEncoder extends MessageToByteEncoder[Packet] with GamePackets with StrictLogging {

  override def encode(ctx: ChannelHandlerContext, msg: Packet, out: ByteBuf): Unit = {
    val crypt = ctx.channel.attr(CRYPT).get
    val unencrypted = isUnencryptedPacket(msg.id)

    val headerSize = if (unencrypted) 4 else 6

    val array = new ArrayBuffer[Byte](headerSize)
    array ++= ByteUtils.shortToBytes(msg.byteBuf.writerIndex + headerSize - 2)
    array ++= ByteUtils.shortToBytesLE(msg.id)
    val header = if (unencrypted) {
      array.toArray
    } else {
      array.append(0, 0)
      crypt.encrypt(array.toArray)
    }

    logger.debug(f"SEND PACKET: ${msg.id}%04X - ${ByteUtils.toHexString(msg.byteBuf, true, false)}")

    out.writeBytes(header)
    out.writeBytes(msg.byteBuf)
    msg.byteBuf.release
  }

  protected def isUnencryptedPacket(id: Int): Boolean = {
    id == CMSG_AUTH_CHALLENGE
  }
} 
Example 166
Source File: RealmPacketDecoder.scala    From wowchat   with GNU General Public License v3.0 5 votes vote down vote up
package wowchat.realm

import java.util

import wowchat.common.{ByteUtils, Packet, WowChatConfig, WowExpansion}
import com.typesafe.scalalogging.StrictLogging
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.ByteToMessageDecoder

class RealmPacketDecoder extends ByteToMessageDecoder with StrictLogging {

  private var size = 0
  private var id = 0

  override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: util.List[AnyRef]): Unit = {
    if (in.readableBytes == 0) {
      return
    }

    if (size == 0 && id == 0) {
      in.markReaderIndex
      id = in.readByte
      id match {
        case RealmPackets.CMD_AUTH_LOGON_CHALLENGE =>
          if (in.readableBytes < 2) {
            in.resetReaderIndex
            return
          }

          in.markReaderIndex
          in.skipBytes(1)
          val result = in.readByte
          size = if (RealmPackets.AuthResult.isSuccess(result)) {
            118
          } else {
            2
          }
          in.resetReaderIndex
        case RealmPackets.CMD_AUTH_LOGON_PROOF =>
          if (in.readableBytes < 1) {
            in.resetReaderIndex
            return
          }

          // size is error dependent
          in.markReaderIndex
          val result = in.readByte
          size = if (RealmPackets.AuthResult.isSuccess(result)) {
            if (WowChatConfig.getExpansion == WowExpansion.Vanilla) 25 else 31
          } else {
            if (WowChatConfig.getExpansion == WowExpansion.Vanilla) 1 else 3
          }
          in.resetReaderIndex
        case RealmPackets.CMD_REALM_LIST =>
          if (in.readableBytes < 2) {
            in.resetReaderIndex
            return
          }

          size = in.readShortLE
      }
    }

    if (size > in.readableBytes) {
      return
    }

    val byteBuf = in.readBytes(size)
    val packet = Packet(id, byteBuf)

    logger.debug(f"RECV REALM PACKET: $id%04X - ${ByteUtils.toHexString(byteBuf, true, false)}")

    out.add(packet)
    size = 0
    id = 0
  }
} 
Example 167
Source File: RealmConnector.scala    From wowchat   with GNU General Public License v3.0 5 votes vote down vote up
package wowchat.realm

import java.net.InetSocketAddress
import java.util.concurrent.TimeUnit

import wowchat.common._
import com.typesafe.scalalogging.StrictLogging
import io.netty.bootstrap.Bootstrap
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.channel.{Channel, ChannelInitializer, ChannelOption}
import io.netty.handler.timeout.IdleStateHandler
import io.netty.util.concurrent.Future

import scala.util.Try

class RealmConnector(realmConnectionCallback: RealmConnectionCallback) extends StrictLogging {

  private var channel: Option[Channel] = None
  private var connected: Boolean = false

  def connect: Unit = {
    logger.info(s"Connecting to realm server ${Global.config.wow.realmlist.host}:${Global.config.wow.realmlist.port}")

    val bootstrap = new Bootstrap
    bootstrap.group(Global.group)
      .channel(classOf[NioSocketChannel])
      .option[java.lang.Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, 10000)
      .option[java.lang.Boolean](ChannelOption.SO_KEEPALIVE, true)
      .remoteAddress(new InetSocketAddress(Global.config.wow.realmlist.host, Global.config.wow.realmlist.port))
      .handler(new ChannelInitializer[SocketChannel]() {

        @throws[Exception]
        override protected def initChannel(socketChannel: SocketChannel): Unit = {
          val handler = if (WowChatConfig.getExpansion == WowExpansion.Vanilla) {
            new RealmPacketHandler(realmConnectionCallback)
          } else {
            new RealmPacketHandlerTBC(realmConnectionCallback)
          }

          socketChannel.pipeline.addLast(
            new IdleStateHandler(60, 120, 0),
            new IdleStateCallback,
            new RealmPacketDecoder,
            new RealmPacketEncoder,
            handler
          )
        }
      })

    channel = Some(bootstrap.connect.addListener((future: Future[_ >: Void]) => {
      Try {
        future.get(10, TimeUnit.SECONDS)
      }.fold(throwable => {
        logger.error(s"Failed to connect to realm server! ${throwable.getMessage}")
        realmConnectionCallback.disconnected
      }, _ => ())
    }).channel)
  }
} 
Example 168
Source File: WoWChat.scala    From wowchat   with GNU General Public License v3.0 5 votes vote down vote up
package wowchat

import java.util.concurrent.{Executors, TimeUnit}

import wowchat.common.{CommonConnectionCallback, Global, ReconnectDelay, WowChatConfig}
import wowchat.discord.Discord
import wowchat.game.GameConnector
import wowchat.realm.{RealmConnectionCallback, RealmConnector}
import com.typesafe.scalalogging.StrictLogging
import io.netty.channel.nio.NioEventLoopGroup

import scala.io.Source

object WoWChat extends StrictLogging {

  private val RELEASE = "v1.3.3"

  def main(args: Array[String]): Unit = {
    logger.info(s"Running WoWChat - $RELEASE")
    val confFile = if (args.nonEmpty) {
      args(0)
    } else {
      logger.info("No configuration file supplied. Trying with default wowchat.conf.")
      "wowchat.conf"
    }
    Global.config = WowChatConfig(confFile)

    checkForNewVersion

    val gameConnectionController: CommonConnectionCallback = new CommonConnectionCallback {

      private val reconnectExecutor = Executors.newSingleThreadScheduledExecutor
      private val reconnectDelay = new ReconnectDelay

      override def connect: Unit = {
        Global.group = new NioEventLoopGroup

        val realmConnector = new RealmConnector(new RealmConnectionCallback {
          override def success(host: String, port: Int, realmName: String, realmId: Int, sessionKey: Array[Byte]): Unit = {
            gameConnect(host, port, realmName, realmId, sessionKey)
          }

          override def disconnected: Unit = doReconnect

          override def error: Unit = sys.exit(1)
        })

        realmConnector.connect
      }

      private def gameConnect(host: String, port: Int, realmName: String, realmId: Int, sessionKey: Array[Byte]): Unit = {
        new GameConnector(host, port, realmName, realmId, sessionKey, this).connect
      }

      override def connected: Unit = reconnectDelay.reset

      override def disconnected: Unit = doReconnect

      def doReconnect: Unit = {
        Global.group.shutdownGracefully()
        Global.discord.changeRealmStatus("Connecting...")
        val delay = reconnectDelay.getNext
        logger.info(s"Disconnected from server! Reconnecting in $delay seconds...")

        reconnectExecutor.schedule(new Runnable {
          override def run(): Unit = connect
        }, delay, TimeUnit.SECONDS)
      }
    }

    logger.info("Connecting to Discord...")
    Global.discord = new Discord(new CommonConnectionCallback {
      override def connected: Unit = gameConnectionController.connect

      override def error: Unit = sys.exit(1)
    })
  }

  private def checkForNewVersion = {
    // This is JSON, but I really just didn't want to import a full blown JSON library for one string.
    val data = Source.fromURL("https://api.github.com/repos/fjaros/wowchat/releases/latest").mkString
    val regex = "\"tag_name\":\"(.+?)\",".r
    val repoTagName = regex
      .findFirstMatchIn(data)
      .map(_.group(1))
      .getOrElse("NOT FOUND")

    if (repoTagName != RELEASE) {
      logger.error( "~~~ !!!                YOUR WoWChat VERSION IS OUT OF DATE                !!! ~~~")
      logger.error(s"~~~ !!!                     Current Version:  $RELEASE                      !!! ~~~")
      logger.error(s"~~~ !!!                     Repo    Version:  $repoTagName                      !!! ~~~")
      logger.error( "~~~ !!! RUN git pull OR GO TO https://github.com/fjaros/wowchat TO UPDATE !!! ~~~")
      logger.error( "~~~ !!!                YOUR WoWChat VERSION IS OUT OF DATE                !!! ~~~")
    }
  }
} 
Example 169
Source File: ReconnectDelay.scala    From wowchat   with GNU General Public License v3.0 5 votes vote down vote up
package wowchat.common

import com.typesafe.scalalogging.StrictLogging

class ReconnectDelay extends StrictLogging {

  private var reconnectDelay: Option[Int] = None

  def reset: Unit = {
    reconnectDelay = None
  }

  def getNext: Int = {
    synchronized {
      reconnectDelay = Some(10)

      val result = reconnectDelay.get
      logger.debug(s"GET RECONNECT DELAY $result")
      result
    }
  }
} 
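ReconnectDelay above always hands back a flat 10 seconds, but it keeps its state in an Option that reset clears, which is the natural shape for a backoff policy. The variant below is only a sketch of how that scaffolding could drive an exponential backoff capped at five minutes; it is not part of wowchat.

class ExponentialReconnectDelaySketch {

  private var reconnectDelay: Option[Int] = None

  def reset(): Unit = synchronized {
    reconnectDelay = None
  }

  // Start at 10 seconds, double on every call, and never exceed 5 minutes.
  def getNext: Int = synchronized {
    val next = reconnectDelay.map(d => math.min(d * 2, 300)).getOrElse(10)
    reconnectDelay = Some(next)
    next
  }
}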
Example 170
Source File: IdleStateCallback.scala    From wowchat   with GNU General Public License v3.0 5 votes vote down vote up
package wowchat.common

import com.typesafe.scalalogging.StrictLogging
import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter}
import io.netty.handler.timeout.{IdleState, IdleStateEvent}

class IdleStateCallback extends ChannelInboundHandlerAdapter with StrictLogging {

  override def userEventTriggered(ctx: ChannelHandlerContext, evt: scala.Any): Unit = {
    evt match {
      case event: IdleStateEvent =>
        val idler = event.state match {
          case IdleState.READER_IDLE => "reader"
          case IdleState.WRITER_IDLE => "writer"
          case _ => "all"
        }
        logger.error(s"Network state for $idler marked as idle!")
        ctx.close
      case _ =>
    }

    super.userEventTriggered(ctx, evt)
  }
} 
Example 171
Source File: AutoScaling.scala    From ionroller   with MIT License 5 votes vote down vote up
package ionroller.aws

import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.services.autoscaling.model._
import com.amazonaws.services.autoscaling.{AmazonAutoScaling, AmazonAutoScalingClient}
import com.amazonaws.services.elasticloadbalancing.model.InstanceState
import com.typesafe.scalalogging.StrictLogging

import scala.collection.JavaConverters._
import scala.concurrent.duration.FiniteDuration
import scalaz._
import scalaz.concurrent.Task

object AutoScaling extends StrictLogging {
  val client: Kleisli[Task, AWSCredentialsProvider, AmazonAutoScaling] = {
    Kleisli { credentialsProvider =>
      Task(new AmazonAutoScalingClient(credentialsProvider))(awsExecutorService)
    }
  }

  def getAutoScalingGroupDetails(asgs: Seq[String]): Kleisli[Task, AWSClientCache, List[AutoScalingGroup]] = {
    Kleisli { client =>

      def go(asgsSoFar: List[AutoScalingGroup], token: Option[String]): Task[List[AutoScalingGroup]] = Task.delay {
        val req =
          if (asgs.isEmpty)
            new DescribeAutoScalingGroupsRequest()
          else
            new DescribeAutoScalingGroupsRequest().withAutoScalingGroupNames(asgs: _*)

        token foreach { t => req.setNextToken(t) }
        val response = client.asg.describeAutoScalingGroups(req)
        (asgsSoFar ::: response.getAutoScalingGroups.asScala.toList, Option(response.getNextToken))
      } flatMap {
        case (events, t @ Some(newToken)) if t != token =>
          logger.debug(s"Needed multiple getAutoScalingGroups calls, token=$token newToken=$t")
          go(events, t)
        case (events, _) =>
          Task.now(events)
      }

      Task.fork(go(List.empty, None))(awsExecutorService)
    }
  }

  def getUnregisteredInstance(elbInstances: Seq[InstanceState], asg: AutoScalingGroup): Option[String] = {
    val lbInstances = elbInstances.map(_.getInstanceId).toSet
    val asgInstances = asg.getInstances.asScala
    val healthyAsgInstances = asgInstances.filter(_.getHealthStatus == "Healthy").map(_.getInstanceId).toSet

    (healthyAsgInstances -- lbInstances).toSeq.sorted.headOption
  }

  def attachElb(asg: AutoScalingGroup, lb: String): Kleisli[Task, AWSClientCache, AttachLoadBalancersResult] = {
    Kleisli { cache =>
      val attachRequest = new AttachLoadBalancersRequest().withAutoScalingGroupName(asg.getAutoScalingGroupName).withLoadBalancerNames(lb)
      Task(cache.asg.attachLoadBalancers(attachRequest))(awsExecutorService)
    }
  }

  def detachElb(asg: AutoScalingGroup, lb: String): Kleisli[Task, AWSClientCache, DetachLoadBalancersResult] = {
    Kleisli { cache =>
      val detachRequest = new DetachLoadBalancersRequest().withAutoScalingGroupName(asg.getAutoScalingGroupName).withLoadBalancerNames(lb)
      Task(cache.asg.detachLoadBalancers(detachRequest))(awsExecutorService)
    }
  }

  def updateElbHealthCheck(asgs: Seq[String], healthCheckType: String, gracePeriod: FiniteDuration): Kleisli[Task, AWSClientCache, Unit] = {
    Kleisli { cache =>

      for {
        _ <- Task.gatherUnordered(
          asgs map { name =>
            Task(
              cache.asg.updateAutoScalingGroup(
                new UpdateAutoScalingGroupRequest()
                  .withAutoScalingGroupName(name)
                  .withHealthCheckType(healthCheckType)
                  .withHealthCheckGracePeriod(gracePeriod.toSeconds.toInt)
              )
            )(awsExecutorService)
          }
        )
      } yield ()
    }
  }
} 
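Because every operation above is a Kleisli[Task, AWSClientCache, _], calls can be composed before any AWS client is touched and only execute once a cache is supplied. A hedged usage sketch, assuming AWSClientCache lives in ionroller.aws as the listing suggests and using made-up ASG/ELB names:

import ionroller.aws.{AWSClientCache, AutoScaling}

import scalaz.concurrent.Task

object ElbRotationSketch {

  // Detach an old load balancer and attach a new one on the first matching ASG.
  def rotateElb(cache: AWSClientCache): Task[Unit] =
    AutoScaling.getAutoScalingGroupDetails(Seq("my-service-asg")).run(cache).flatMap {
      case asg :: _ =>
        AutoScaling.detachElb(asg, "old-elb").run(cache)
          .flatMap(_ => AutoScaling.attachElb(asg, "new-elb").run(cache))
          .map(_ => ())
      case Nil =>
        Task.now(())
    }
}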
Example 172
Source File: ConfigurationManager.scala    From ionroller   with MIT License 5 votes vote down vote up
package ionroller

import com.amazonaws.services.elasticbeanstalk.model.ConfigurationOptionSetting
import com.typesafe.config.{Config, ConfigException, ConfigFactory}
import com.typesafe.scalalogging.StrictLogging
import ionroller.aws.Dynamo
import play.api.libs.json._

import scalaz.concurrent.Task
import scalaz.stream._
import scalaz.{-\/, \/-}

class ConfigurationManager(initialConfig: SystemConfiguration) extends StrictLogging {
  val configurationSignal = async.signalOf(initialConfig)
}

object ConfigurationManager extends StrictLogging {
  def apply(initialConfig: SystemConfiguration): ConfigurationManager =
    new ConfigurationManager(initialConfig)

  val confFile: Config = ConfigFactory.load()
  val defaultSolutionStack: String = confFile.getString("ionroller.solution-stack-name")
  val whitelistKey = "ionroller.modify-environments-whitelist"
  val blacklistKey = "ionroller.modify-environments-blacklist"

  val modifyEnvironmentslist: (String, Boolean) => Set[TimelineName] = {
    (key, required) =>
      try {
        confFile.getString(key)
          .split(",")
          .map(_.trim)
          .collect({ case s: String if !s.isEmpty && s != "ALL" => s })
          .map(TimelineName.apply).toSet
      } catch {
        case ex: ConfigException.Missing => {
          if (required) {
            logger.error(s"${key} $required configuration is missing.\nRun ION-Roller with property: -D${key}=[ALL|<TIMELINE_NAME_1,TIMELINE_NAME_2,...>]")
            throw ex
          } else Set.empty
        }
      }
  }

  val modifyEnvironmentsWhitelist = modifyEnvironmentslist(whitelistKey, true)
  val modifyEnvironmentsBlacklist = modifyEnvironmentslist(blacklistKey, false)

  logger.debug("Processing timelines: " + {
    if (modifyEnvironmentsWhitelist.isEmpty) "ALL" else modifyEnvironmentsWhitelist
  } + {
    if (!modifyEnvironmentsBlacklist.isEmpty) " excluding: " + modifyEnvironmentsBlacklist else ""
  })
  val modifyEnvironments = confFile.getBoolean("ionroller.modify-environments")

  val defaultOptionSettings: Seq[ConfigurationOptionSetting] = {
    val options = confFile.getString("ionroller.option-settings")

    Task(Json.parse(options).as[Seq[ConfigurationOptionSetting]]).attemptRun match {
      case \/-(settings) => settings
      case -\/(t) => {
        logger.error(t.getMessage, t)
        Seq.empty
      }
    }
  }

  val defaultResources: JsObject = {
    val resources = confFile.getString("ionroller.resources")

    Task(Json.parse(resources).as[JsObject]).attemptRun match {
      case \/-(resources) => resources
      case -\/(t) => {
        logger.error(t.getMessage, t)
        JsObject(Seq.empty)
      }
    }
  }

  def getSavedConfiguration: Task[SystemConfiguration] = {
    for {
      table <- Dynamo.configTable(None)
      systemConfig <- Dynamo.getSystemConfig(table)
    } yield systemConfig
  }

} 
Example 173
Source File: package.scala    From ionroller   with MIT License 5 votes vote down vote up
import java.util.concurrent.{ExecutorService, Executors, ScheduledExecutorService}

import com.amazonaws.services.elasticbeanstalk.model.ConfigurationOptionSetting
import com.typesafe.scalalogging.StrictLogging
import ionroller.aws.Dynamo
import ionroller.tracking.Event
import play.api.libs.functional.syntax._
import play.api.libs.json._

import scala.concurrent.duration.FiniteDuration
import scalaz.concurrent.Task
import scalaz.{-\/, \/-}

package object ionroller extends StrictLogging {
  val ionrollerExecutorService: ExecutorService = Executors.newFixedThreadPool(4)

  implicit val `| Implicit executor service        |`: ExecutorService = ionrollerExecutorService
  implicit val ` | is disabled - define explicitly  |`: ExecutorService = ionrollerExecutorService

  implicit val timer: ScheduledExecutorService = scalaz.concurrent.Strategy.DefaultTimeoutScheduler

  def ionrollerRole(awsAccountId: String) = s"arn:aws:iam::$awsAccountId:role/ionroller"

  implicit lazy val finiteDurationFormat = {

    def applyFiniteDuration(l: Long, u: String): FiniteDuration = {
      FiniteDuration(l, u.toLowerCase)
    }

    def unapplyFiniteDuration(d: FiniteDuration): (Long, String) = {
      (d.length, d.unit.toString)
    }

    ((JsPath \ "length").format[Long] and
      (JsPath \ "unit").format[String])(applyFiniteDuration, unapplyFiniteDuration)
  }

  implicit lazy val configurationOptionSettingFormat: Format[ConfigurationOptionSetting] = {
    def applyConfigOptionSetting(ns: String, optionName: String, value: String) =
      new ConfigurationOptionSetting(ns, optionName, value)

    def unapplyConfigOptionSetting(o: ConfigurationOptionSetting): Option[(String, String, String)] = {
      for {
        ns <- Option(o.getNamespace)
        n <- Option(o.getOptionName)
        v <- Option(o.getValue)
      } yield (ns, n, v)
    }

    ((JsPath \ "Namespace").format[String] and
      (JsPath \ "OptionName").format[String] and
      (JsPath \ "Value").format[String])(applyConfigOptionSetting _, unlift(unapplyConfigOptionSetting))
  }

  def enabled(name: TimelineName) = {
    ConfigurationManager.modifyEnvironments &&
      (ConfigurationManager.modifyEnvironmentsWhitelist.isEmpty || ConfigurationManager.modifyEnvironmentsWhitelist.contains(name)) &&
      !ConfigurationManager.modifyEnvironmentsBlacklist.contains(name)
  }

  def logEvent(evt: Event) = {
    logger.info(s"$evt (enabled = ${enabled(evt.service)})")
    if (enabled(evt.service))
      Dynamo.EventLogger.log(evt)
        .flatMap({
          case \/-(s) => Task.now(())
          case -\/(f) => Task.delay(logger.error(f.getMessage, f))
        })
    else Task.now(())
  }

} 
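The finiteDurationFormat above round-trips durations as a {length, unit} object: the TimeUnit name is written as-is and lower-cased on read. A small sketch with made-up values, assuming import ionroller._ brings the implicit format into scope:

import ionroller._
import play.api.libs.json.Json

import scala.concurrent.duration.FiniteDuration

object DurationFormatSketch extends App {
  // Read: {"length": 30, "unit": "SECONDS"} becomes a 30-second FiniteDuration.
  val parsed = Json.parse("""{"length": 30, "unit": "SECONDS"}""").as[FiniteDuration]
  println(parsed) // 30 seconds

  // Write: the unit is emitted in upper case, e.g. {"length":5,"unit":"MINUTES"}.
  println(Json.toJson(FiniteDuration(5, "minutes")))
}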
Example 174
Source File: TestServer.scala    From gatling-grpc   with Apache License 2.0 5 votes vote down vote up
package com.github.phisgr.example

import java.util.concurrent.{ConcurrentHashMap, TimeUnit}

import com.github.phisgr.example.greet._
import com.github.phisgr.example.util._
import com.typesafe.scalalogging.StrictLogging
import io.grpc._
import io.grpc.health.v1.health.HealthCheckResponse.ServingStatus.SERVING
import io.grpc.health.v1.health.HealthGrpc.Health
import io.grpc.health.v1.health.{HealthCheckRequest, HealthCheckResponse, HealthGrpc}
import io.grpc.stub.StreamObserver

import scala.collection.JavaConverters._
import scala.concurrent.Future
import scala.util.{Random, Try}

object TestServer extends StrictLogging {
  def startServer(): Server = {

    val accounts: collection.concurrent.Map[String, String] = new ConcurrentHashMap[String, String]().asScala

    val greetService = new GreetServiceGrpc.GreetService {
      override def greet(request: HelloWorld) = Future.fromTry(Try {
        val token = Option(TokenContextKey.get).getOrElse {
          val trailers = new Metadata()
          trailers.put(ErrorResponseKey, CustomError("You are not authenticated!"))
          throw Status.UNAUTHENTICATED.asException(trailers)
        }

        val username = request.username
        if (!accounts.get(username).contains(token)) throw Status.PERMISSION_DENIED.asException()
        ChatMessage(username = username, data = s"Server says: Hello ${request.name}!")
      })

      override def register(request: RegisterRequest) = Future.fromTry(Try {
        val token = new Random().alphanumeric.take(10).mkString
        val success = accounts.putIfAbsent(request.username, token).isEmpty

        if (success) {
          RegisterResponse(
            username = request.username,
            token = token
          )
        } else {
          val trailers = new Metadata()
          trailers.put(ErrorResponseKey, CustomError("The username is already taken!"))
          throw Status.ALREADY_EXISTS.asException(trailers)
        }
      })
    }

    // normally, it just adds the "token" header, if any, to the context
    // but for demo purposes, it fails the call with a 0.001% chance
    val interceptor = new ServerInterceptor {
      override def interceptCall[ReqT, RespT](
        call: ServerCall[ReqT, RespT], headers: Metadata,
        next: ServerCallHandler[ReqT, RespT]
      ): ServerCall.Listener[ReqT] = {
        if (new Random().nextInt(100000) == 0) {
          val trailers = new Metadata()
          trailers.put(ErrorResponseKey, CustomError("1 in 100,000 chance!"))
          call.close(Status.UNAVAILABLE.withDescription("You're unlucky."), trailers)
          new ServerCall.Listener[ReqT] {}
        } else {
          val context = Context.current()
          val newContext = Option(headers.get(TokenHeaderKey)).fold(context)(context.withValue(TokenContextKey, _))
          Contexts.interceptCall(newContext, call, headers, next)
        }

      }
    }

    val port = 8080
    val server = ServerBuilder.forPort(port)
      .addService(GreetServiceGrpc.bindService(greetService, scala.concurrent.ExecutionContext.global))
      .intercept(interceptor)
      .build.start
    logger.info(s"Server started, listening on $port")

    server
  }

  def startEmptyServer() = {
    val service = new Health {
      override def check(request: HealthCheckRequest): Future[HealthCheckResponse] =
        Future.successful(HealthCheckResponse(SERVING))
      override def watch(request: HealthCheckRequest, responseObserver: StreamObserver[HealthCheckResponse]): Unit =
        responseObserver.onError(Status.UNIMPLEMENTED.asRuntimeException())
    }
    ServerBuilder.forPort(9999)
      .addService(HealthGrpc.bindService(service, scala.concurrent.ExecutionContext.global))
      .build.start
  }

  def main(args: Array[String]): Unit = {
    val server = startServer()
    server.awaitTermination(10, TimeUnit.MINUTES)
  }
} 
Example 175
Source File: Main.scala    From ForestFlow   with Apache License 2.0 5 votes vote down vote up
package ai.forestflow.event.subscribers

import java.net.InetAddress

import akka.Done
import akka.actor.ActorSystem
import akka.actor.CoordinatedShutdown.PhaseBeforeActorSystemTerminate
import akka.cluster.Cluster
import akka.management.cluster.bootstrap.ClusterBootstrap
import akka.management.scaladsl.AkkaManagement
import com.typesafe.config.Config
import com.typesafe.scalalogging.StrictLogging

import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success, Try}

object Main extends StrictLogging {
  def main(args: Array[String]): Unit = {
    import ai.forestflow.startup.ActorSystemStartup._

    preStartup(typeSafeConfig)

    logger.info(s"Started system: [$system], cluster.selfAddress = ${cluster.selfAddress}")

    shutdown.addTask(PhaseBeforeActorSystemTerminate, "main.cleanup") { () => cleanup(typeSafeConfig) }

    bootstrapCluster(system, cluster)

    logger.info(s"Sharding lease owner for this node will be set to: ${cluster.selfAddress.hostPort}")

    // Start application after self member joined the cluster (Up)
    cluster.registerOnMemberUp({
      logger.info(s"Cluster Member is up: ${cluster.selfMember.toString()}")
      postStartup
    })

  }

  private def bootstrapCluster(system: ActorSystem, cluster: Cluster): Unit = {
    // Akka Management hosts the HTTP routes used by bootstrap
    AkkaManagement(system).start()

    // Starting the bootstrap process needs to be done explicitly
    ClusterBootstrap(system).start()

    system.log.info(s"Akka Management hostname from InetAddress.getLocalHost.getHostAddress is: ${InetAddress.getLocalHost.getHostAddress}")
  }

  private def preStartup(config: Config): Unit = {

  }

  private def postStartup(implicit system: ActorSystem, config: Config): Unit = {
    // Kafka Prediction Logger setup
    import system.log

    val basic_topic = Try(config.getString("application.kafka-prediction-logger.basic-topic")).toOption
    val gp_topic = Try(config.getString("application.kafka-prediction-logger.graphpipe-topic")).toOption

    if (basic_topic.isDefined || gp_topic.isDefined){
      log.info(s"Setting up Kafka prediction logging with basic_topic: $basic_topic graphpipe_topic: $gp_topic")
      val predictionLogger = system.actorOf(PredictionLogger.props(basic_topic, gp_topic))
    }
  }

  private def cleanup(config: Config)(implicit executionContext: ExecutionContext) = Future {
    Done
  }
} 
Example 176
Source File: ApplicationEnvironment.scala    From ForestFlow   with Apache License 2.0 5 votes vote down vote up
package ai.forestflow.serving.config

import java.util.concurrent.TimeUnit

import com.typesafe.config.{Config, ConfigFactory, ConfigRenderOptions}
import com.typesafe.scalalogging.StrictLogging

import scala.util.Try

//noinspection TypeAnnotation
object ApplicationEnvironment extends StrictLogging {

  def getConfig(env: Option[String]): Config = {
    val base = ConfigFactory.load()
    val defaults = base.getConfig("defaults")
    env match {
      case Some(envName) => base.getConfig(envName) withFallback defaults
      case None => defaults
    }
  }

  private lazy val applicationEnvironment = Try(ConfigFactory.load().getString("application.environment")).toOption
  logger.info(s"Application environment: $applicationEnvironment")
  lazy val config: Config = getConfig(applicationEnvironment)
  logger.debug(config.root().render(ConfigRenderOptions.concise()))

  lazy val SYSTEM_NAME = config.getString("application.system-name")

  lazy val MAX_NUMBER_OF_SHARDS = {
    val maxShards = config.getInt("application.max-number-of-shards")
    require(maxShards >= 1, "max-number-of-shards must be >= 1")
    maxShards
  }
  lazy val HTTP_COMMAND_TIMEOUT_SECS = {
    val duration = config.getDuration("application.http-command-timeout", TimeUnit.SECONDS)
    require(duration > 1, "http-command-timeout cannot be less than 1 second")
    duration
  }
  lazy val HTTP_PORT = config.getInt("application.http-port")
  lazy val HTTP_BIND_ADDRESS = Try(config.getString("http-bind-address")).getOrElse("0.0.0.0")

} 
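ApplicationEnvironment resolves an environment-specific block against a defaults block and then reads application.* keys from the result. The fragment below is only inferred from those lookups and is parsed inline for illustration; the real ForestFlow reference.conf may organize these keys differently.

import com.typesafe.config.ConfigFactory

object ApplicationEnvironmentConfigSketch extends App {
  val sample = ConfigFactory.parseString(
    """
      |application.environment = "local"
      |local {}
      |defaults {
      |  application {
      |    system-name = "forestflow"
      |    max-number-of-shards = 30
      |    http-command-timeout = 30 seconds
      |    http-port = 8090
      |  }
      |}
      |""".stripMargin)

  // Mirror getConfig(Some("local")): the env block falls back to the defaults block.
  val resolved = sample.getConfig("local") withFallback sample.getConfig("defaults")
  println(resolved.getString("application.system-name")) // forestflow
}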
Example 177
Source File: FairPhaseInPctBasedRouter.scala    From ForestFlow   with Apache License 2.0 5 votes vote down vote up
package ai.forestflow.serving.impl.ContractRouters

import ai.forestflow.serving.impl.ServableMetricsImpl
import ai.forestflow.serving.interfaces.{ContractRouter, RouterType}
import ai.forestflow.utils.RFWeightedCollection.WeightedItem
import ai.forestflow.domain.FQRV
import ai.forestflow.utils.RFWeightedCollection
import com.typesafe.scalalogging.StrictLogging

import scala.collection.mutable.ArrayBuffer


trait FairPhaseInPctBasedRouterImpl {
  this: RouterType =>
  override def create(servableStats: List[(FQRV, ServableMetricsImpl)]): ContractRouter = FairPhaseInPctBasedRouter.create(servableStats)
}


object FairPhaseInPctBasedRouter extends StrictLogging {
  // Note: the create and getWeightedItems members of this companion object are elided from this listing.

  def apply(servableStats: List[(FQRV, ServableMetricsImpl)]): FairPhaseInPctBasedRouter = {
    logger.debug(s"Creating RFWeightedCollection from $servableStats")
    FairPhaseInPctBasedRouter(RFWeightedCollection(getWeightedItems(servableStats))())
  }
}

@SerialVersionUID(0L)
final case class FairPhaseInPctBasedRouter(private val collection: RFWeightedCollection[FQRV]) extends ContractRouter {
  import FairPhaseInPctBasedRouter._

  override def next(): Option[FQRV] = collection.next().map(_.item)

  override def merge(servableStats: List[(FQRV, ServableMetricsImpl)]): ContractRouter = {
    // If only weights are being updated to existing collection list
    if (collection.items.map(_.item).toSet == servableStats.map(_._1).toSet) {
      FairPhaseInPctBasedRouter(collection.updateWeights(getWeightedItems(servableStats)))
    }
    else // new or deleted items, create anew.
      FairPhaseInPctBasedRouter(servableStats)
  }
} 
Example 178
Source File: LatestPhaseInPctBasedRouter.scala    From ForestFlow   with Apache License 2.0 5 votes vote down vote up
package ai.forestflow.serving.impl.ContractRouters

import ai.forestflow.serving.impl.ServableMetricsImpl
import ai.forestflow.serving.interfaces.{ContractRouter, RouterType}
import ai.forestflow.domain.FQRV
import ai.forestflow.utils.RFWeightedCollection
import ai.forestflow.utils.RFWeightedCollection.WeightedItem
import com.typesafe.scalalogging.StrictLogging

import scala.collection.mutable.ArrayBuffer


trait LatestPhaseInPctBasedRouterImpl {
  this: RouterType =>
  override def create(servableStats: List[(FQRV, ServableMetricsImpl)]): ContractRouter = LatestPhaseInPctBasedRouter.create(servableStats)
}


object LatestPhaseInPctBasedRouter extends StrictLogging {
  // Note: the create, getWeightedItems and getFilteredServableStats members of this companion object are elided from this listing.

  def apply(servableStats: List[(FQRV, ServableMetricsImpl)]): LatestPhaseInPctBasedRouter = {
    logger.debug(s"Creating RFWeightedCollection from $servableStats")
    LatestPhaseInPctBasedRouter(RFWeightedCollection(getWeightedItems(getFilteredServableStats(servableStats)))())
  }
}

@SerialVersionUID(2965108754223283566L)
final case class LatestPhaseInPctBasedRouter(private val collection: RFWeightedCollection[FQRV]) extends ContractRouter {
  import LatestPhaseInPctBasedRouter._

  override def next(): Option[FQRV] = collection.next().map(_.item)

  override def merge(servableStats: List[(FQRV, ServableMetricsImpl)]): ContractRouter = {
    // If only weights are being updated to existing collection list
    val filteredServables = getFilteredServableStats(servableStats)
    if (collection.items.map(_.item).toSet == filteredServables.map(_._1).toSet) {
      LatestPhaseInPctBasedRouter(collection.updateWeights(getWeightedItems(filteredServables)))
    }
    else // new or deleted items, create anew.
      LatestPhaseInPctBasedRouter(filteredServables)
  }
} 
Example 179
Source File: MqttComponents.scala    From gatling-mqtt-protocol   with Apache License 2.0 5 votes vote down vote up
package com.github.jeanadrien.gatling.mqtt.protocol

import akka.actor.{ActorRef, ActorSystem}
import com.github.jeanadrien.gatling.mqtt.client.{FuseSourceMqttClient, MqttClient}
import com.typesafe.scalalogging.StrictLogging
import io.gatling.commons.validation.Validation
import io.gatling.core.protocol.ProtocolComponents
import io.gatling.core.session._
import org.fusesource.mqtt.client.CallbackConnection


case class MqttComponents(
    mqttProtocol : MqttProtocol, system : ActorSystem
) extends ProtocolComponents with StrictLogging {

    def mqttEngine(
        session : Session, connectionSettings : ConnectionSettings, gatlingMqttId : String
    ) : Validation[ActorRef] = {
        logger.debug(s"MqttComponents: new mqttEngine: ${gatlingMqttId}")
        mqttProtocol.configureMqtt(session).map { config =>
            // inject the selected engine
            val mqttClient = system.actorOf(MqttClient.clientInjection(config, gatlingMqttId))
            mqttClient
        }
    }

    override def onStart : Option[(Session) => Session] = Some(s => {
        logger.debug("MqttComponents: onStart");
        s
    })

    override def onExit : Option[(Session) => Unit] = Some(s => {
        logger.debug("MqttComponents: onExit");
        s("engine").asOption[ActorRef].foreach { mqtt =>
            system.stop(mqtt)
        }
    })
} 
Example 180
Source File: FuseSourceConnectionListener.scala    From gatling-mqtt-protocol   with Apache License 2.0 5 votes vote down vote up
package com.github.jeanadrien.gatling.mqtt.client

import akka.actor.ActorRef
import com.typesafe.scalalogging.StrictLogging
import org.fusesource.hawtbuf.{Buffer, UTF8Buffer}
import org.fusesource.mqtt.client.Listener


class FuseSourceConnectionListener(actor : ActorRef) extends Listener with StrictLogging {

    override def onPublish(
        topic : UTF8Buffer, body : Buffer,
        ack   : Runnable
    ) : Unit = {
        val topicStr = topic.toString()
        val bodyStr = body.toByteArray()

        logger.trace(s"Listener receives: topic=${topicStr}, body=${bodyStr}")

        actor ! MqttCommands.OnPublish(topicStr, bodyStr)
        ack.run()
    }

    override def onConnected() : Unit = {
        logger.debug(s"Client is now connected.")
    }

    override def onFailure(value : Throwable) : Unit = {
        logger.error(s"Listener: onFailure: ${value.getMessage}")
    }

    override def onDisconnected() : Unit = {
        logger.debug(s"Client has been disconnected.")
    }
} 
Example 181
Source File: AtlasJob.scala    From comet-data-pipeline   with Apache License 2.0 5 votes vote down vote up
package com.ebiznext.comet.job.atlas

import com.ebiznext.comet.config.Settings
import com.ebiznext.comet.schema.handlers.StorageHandler
import com.ebiznext.comet.schema.model.atlas.AtlasModel
import com.typesafe.scalalogging.StrictLogging

class AtlasJob(
  cliConfig: AtlasConfig,
  storageHandler: StorageHandler
)(implicit settings: Settings)
    extends StrictLogging {

  def run(): Unit = {
    logger.info(s"")
    val uris = cliConfig.uris.map(_.toArray).getOrElse(Array(settings.comet.atlas.uri))
    val userPassword = (cliConfig.user, cliConfig.password) match {
      case (Some(user), Some(pwd)) => Array(user, pwd)
      case _                       => Array(settings.comet.atlas.user, settings.comet.atlas.password)
    }
    new AtlasModel(uris, userPassword).run(cliConfig, storageHandler)
  }
} 
Example 182
Source File: SparkEnv.scala    From comet-data-pipeline   with Apache License 2.0 5 votes vote down vote up
package com.ebiznext.comet.config

import java.time.LocalDateTime
import java.time.format.DateTimeFormatter

import com.typesafe.scalalogging.StrictLogging
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession


  lazy val session: SparkSession = {
    val session =
      if (settings.comet.hive)
        SparkSession.builder.config(config).enableHiveSupport().getOrCreate()
      else
        SparkSession.builder.config(config).getOrCreate()
    logger.info("Spark Version -> " + session.version)
    logger.info(session.conf.getAll.mkString("\n"))
    session
  }

} 
Example 183
Source File: MultiServiceSpec.scala    From akka-http-spring-boot   with Apache License 2.0 5 votes vote down vote up
package com.github.scalaspring.akka.http

import java.net.ServerSocket

import akka.http.scaladsl.client.RequestBuilding._
import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import akka.http.scaladsl.unmarshalling.Unmarshal
import com.github.scalaspring.scalatest.TestContextManagement
import com.typesafe.scalalogging.StrictLogging
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{FlatSpec, Matchers}
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.SpringApplicationContextLoader
import org.springframework.context.annotation.{Bean, Import}
import org.springframework.test.context.ContextConfiguration
import resource._

import scala.concurrent.duration._

@ContextConfiguration(
  loader = classOf[SpringApplicationContextLoader],
  classes = Array(classOf[MultiServiceSpec.Configuration])
)
class MultiServiceSpec extends FlatSpec with TestContextManagement with AkkaStreamsAutowiredImplicits with Matchers with ScalaFutures with StrictLogging {

  implicit val patience = PatienceConfig(10.seconds)    // Allow time for server startup

  @Autowired val settings: ServerSettings = null
  @Autowired val client: HttpClient = null

  "Echo service" should "echo" in {
    val name = "name"
    val future = client.request(Get(s"http://${settings.interface}:${settings.port}/multi/echo/$name"))

    whenReady(future) { response =>
      response.status shouldBe OK
      whenReady(Unmarshal(response.entity).to[String])(_ shouldBe name)
    }
  }

  "Reverse service" should "reverse" in {
    val name = "name"
    val future = client.request(Get(s"http://${settings.interface}:${settings.port}/multi/reverse/$name"))

    whenReady(future) { response =>
      response.status shouldBe OK
      whenReady(Unmarshal(response.entity).to[String])(_ shouldBe name.reverse)
    }
  }

}

object MultiServiceSpec {

  @Configuration
  @Import(Array(classOf[AkkaHttpServerAutoConfiguration]))
  class Configuration extends AkkaHttpServer with EchoService with ReverseService {
    @Bean
    def serverSettings = new ServerSettings(port = managed(new ServerSocket(0)).map(_.getLocalPort).opt.get)
  }

  trait EchoService extends AkkaHttpService {
    abstract override def route: Route = {
      (get & path("multi"/"echo"/Segment)) { name =>
        complete(name)
      } ~ super.route
    }
  }

  trait ReverseService extends AkkaHttpService {
    abstract override def route: Route = {
      (get & path("multi"/"reverse"/Segment)) { name =>
        complete(name.reverse)
      }
    } ~ super.route
  }

} 
Example 184
Source File: SingleServiceSpec.scala    From akka-http-spring-boot   with Apache License 2.0 5 votes vote down vote up
package com.github.scalaspring.akka.http

import java.net.ServerSocket

import akka.http.scaladsl.client.RequestBuilding._
import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import akka.http.scaladsl.unmarshalling.Unmarshal
import com.github.scalaspring.scalatest.TestContextManagement
import com.typesafe.scalalogging.StrictLogging
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{FlatSpec, Matchers}
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.SpringApplicationContextLoader
import org.springframework.context.annotation.{Bean, Import}
import org.springframework.test.context.ContextConfiguration
import resource._

import scala.concurrent.duration._

@ContextConfiguration(
  loader = classOf[SpringApplicationContextLoader],
  classes = Array(classOf[SingleServiceSpec.Configuration])
)
class SingleServiceSpec extends FlatSpec with TestContextManagement with AkkaStreamsAutowiredImplicits with Matchers with ScalaFutures with StrictLogging {

  implicit val patience = PatienceConfig(10.seconds)    // Allow time for server startup

  @Autowired val settings: ServerSettings = null
  @Autowired val client: HttpClient = null

  "Echo service" should "echo" in {
    val name = "name"
    val future = client.request(Get(s"http://${settings.interface}:${settings.port}/single/echo/$name"))

    whenReady(future) { response =>
      //logger.info(s"""received response "$response"""")
      response.status shouldBe OK
      whenReady(Unmarshal(response.entity).to[String])(_ shouldBe name)
    }
  }

}


object SingleServiceSpec {

  @Configuration
  @Import(Array(classOf[AkkaHttpServerAutoConfiguration]))
  class Configuration extends AkkaHttpServer with EchoService {
    @Bean
    def serverSettings = new ServerSettings(port = managed(new ServerSocket(0)).map(_.getLocalPort).opt.get)
  }

  trait EchoService extends AkkaHttpService {
    abstract override def route: Route = {
      get {
        path("single"/ "echo" / Segment) { name =>
          complete(name)
        }
      }
    } ~ super.route
  }

} 
Example 185
Source File: RouteTestSpec.scala    From akka-http-spring-boot   with Apache License 2.0 5 votes vote down vote up
package com.github.scalaspring.akka.http

import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import akka.http.scaladsl.testkit.ScalatestRouteTest
import com.github.scalaspring.akka.http.RouteTestSpec.EchoService
import com.typesafe.scalalogging.StrictLogging
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{FlatSpec, Matchers}


class RouteTestSpec extends FlatSpec with EchoService with ScalatestRouteTest with Matchers with ScalaFutures with StrictLogging {

  "Echo service" should "echo" in {
    Get(s"/single/echo/hello") ~> route ~> check {
      status shouldBe OK
    }
  }

}


object RouteTestSpec {

  trait EchoService extends AkkaHttpService {
    abstract override def route: Route = {
      get {
        path("single"/ "echo" / Segment) { name =>
          complete(name)
        }
      }
    } ~ super.route
  }

} 
Example 186
Source File: TestAvroProducer.scala    From asura   with MIT License 5 votes vote down vote up
package asura.kafka.producer

import akka.actor.ActorSystem
import akka.kafka.ProducerSettings
import akka.kafka.scaladsl.Producer
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.Source
import asura.kafka.avro.SampleAvroClass
import com.typesafe.scalalogging.StrictLogging
import io.confluent.kafka.serializers.{AbstractKafkaAvroSerDeConfig, KafkaAvroDeserializerConfig, KafkaAvroSerializer}
import org.apache.avro.specific.SpecificRecord
import org.apache.kafka.clients.producer.ProducerRecord
import org.apache.kafka.common.serialization._

import scala.collection.JavaConverters._

// https://doc.akka.io/docs/alpakka-kafka/current/serialization.html
object TestAvroProducer extends StrictLogging {

  def main(args: Array[String]): Unit = {

    implicit val system = ActorSystem("producer")
    implicit val materializer = ActorMaterializer()
    implicit val ec = system.dispatcher

    val schemaRegistryUrl = ""
    val bootstrapServers = ""
    val topic = ""

    val kafkaAvroSerDeConfig = Map[String, Any](
      AbstractKafkaAvroSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG -> schemaRegistryUrl,
      KafkaAvroDeserializerConfig.SPECIFIC_AVRO_READER_CONFIG -> true.toString
    )
    val producerSettings: ProducerSettings[String, SpecificRecord] = {
      val kafkaAvroSerializer = new KafkaAvroSerializer()
      kafkaAvroSerializer.configure(kafkaAvroSerDeConfig.asJava, false)
      val serializer = kafkaAvroSerializer.asInstanceOf[Serializer[SpecificRecord]]

      ProducerSettings(system, new StringSerializer, serializer)
        .withBootstrapServers(bootstrapServers)
    }

    val samples = (1 to 3).map(i => SampleAvroClass(s"key_$i", s"name_$i"))
    val done = Source(samples)
      .map(n => new ProducerRecord[String, SpecificRecord](topic, n.key, n))
      .runWith(Producer.plainSink(producerSettings))

    done onComplete {
      case scala.util.Success(_) => logger.info("Done"); system.terminate()
      case scala.util.Failure(err) => logger.error(err.toString); system.terminate()
    }
  }
} 
Example 187
Source File: TestProducer.scala    From asura   with MIT License 5 votes vote down vote up
package asura.kafka.producer

import akka.Done
import akka.actor.ActorSystem
import akka.kafka.ProducerSettings
import akka.kafka.scaladsl.Producer
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.Source
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.clients.producer.ProducerRecord
import org.apache.kafka.common.serialization.StringSerializer

import scala.concurrent.Future

object TestProducer extends StrictLogging {

  def main(args: Array[String]): Unit = {

    logger.info("Start producer")

    implicit val system = ActorSystem("producer")
    implicit val materializer = ActorMaterializer()
    implicit val ec = system.dispatcher

    val producerSettings = ProducerSettings(system, new StringSerializer, new StringSerializer)
    val done: Future[Done] =
      Source(1 to 100)
        .map(value => new ProducerRecord[String, String]("test-topic", s"msg ${value}"))
        .runWith(Producer.plainSink(producerSettings))

    done onComplete {
      case scala.util.Success(_) => logger.info("Done"); system.terminate()
      case scala.util.Failure(err) => logger.error(err.toString); system.terminate()
    }
  }
} 
Example 188
Source File: TestConsumer.scala    From asura   with MIT License 5 votes vote down vote up
package asura.kafka.consumer

import akka.actor.ActorSystem
import akka.kafka.scaladsl.Consumer
import akka.kafka.{ConsumerSettings, Subscriptions}
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.Sink
import com.typesafe.scalalogging.StrictLogging
import org.apache.kafka.common.serialization.StringDeserializer

object TestConsumer extends StrictLogging {

  def main(args: Array[String]): Unit = {

    logger.info("Start consumer")

    implicit val system = ActorSystem("consumer")
    implicit val materializer = ActorMaterializer()
    implicit val ec = system.dispatcher

    val consumerSettings = ConsumerSettings(system, new StringDeserializer, new StringDeserializer)
      .withGroupId("test-group1")

    val done = Consumer
      .plainSource(consumerSettings, Subscriptions.topics("test-topic"))
      .runWith(Sink.foreach(record =>
        logger.info(s"topic:${record.topic()}, partition:${record.partition()}, offset:${record.offset()}, key:${record.key()}, value: ${record.value()}"))
      )
    done onComplete {
      case scala.util.Success(_) => logger.info("Done"); system.terminate()
      case scala.util.Failure(err) => logger.error(err.toString); system.terminate()
    }
  }
} 
Example 189
Source File: TestAvroConsumer.scala    From asura   with MIT License 5 votes vote down vote up
package asura.kafka.consumer

import akka.actor.ActorSystem
import akka.kafka.scaladsl.Consumer
import akka.kafka.{ConsumerSettings, Subscriptions}
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.{Keep, Sink}
import asura.kafka.avro.SampleAvroClass
import com.typesafe.scalalogging.StrictLogging
import io.confluent.kafka.serializers.{AbstractKafkaAvroSerDeConfig, KafkaAvroDeserializer, KafkaAvroDeserializerConfig}
import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.common.serialization._

import scala.collection.JavaConverters._

object TestAvroConsumer extends StrictLogging {

  def main(args: Array[String]): Unit = {

    implicit val system = ActorSystem("consumer")
    implicit val materializer = ActorMaterializer()
    implicit val ec = system.dispatcher

    val schemaRegistryUrl = ""
    val bootstrapServers = ""
    val topic = ""
    val group = ""

    val kafkaAvroSerDeConfig = Map[String, Any](
      AbstractKafkaAvroSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG -> schemaRegistryUrl,
      KafkaAvroDeserializerConfig.SPECIFIC_AVRO_READER_CONFIG -> true.toString
    )
    val consumerSettings: ConsumerSettings[String, SampleAvroClass] = {
      val kafkaAvroDeserializer = new KafkaAvroDeserializer()
      kafkaAvroDeserializer.configure(kafkaAvroSerDeConfig.asJava, false)
      val deserializer = kafkaAvroDeserializer.asInstanceOf[Deserializer[SampleAvroClass]]

      ConsumerSettings(system, new StringDeserializer, deserializer)
        .withBootstrapServers(bootstrapServers)
        .withGroupId(group)
        .withProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
    }

    val samples = (1 to 3)
    val (control, result) = Consumer
      .plainSource(consumerSettings, Subscriptions.topics(topic))
      .take(samples.size.toLong)
      .map(_.value())
      .toMat(Sink.seq)(Keep.both)
      .run()

    control.shutdown()
    result.map(records => records.foreach(record => logger.info(s"${record}")))
  }
} 
Example 190
Source File: FileSystemAnnotationDataService.scala    From qamr   with MIT License 5 votes vote down vote up
package qamr

import spacro.util._

import scala.util.{Try, Success}
import java.nio.file.Path
import java.nio.file.Files

import com.typesafe.scalalogging.StrictLogging

class FileSystemAnnotationDataService(dataPath: Path) extends AnnotationDataService {

  private[this] def getDataDirectoryPath = Try {
    val directory = dataPath
    if(!Files.exists(directory)) {
      Files.createDirectories(directory)
    }
    directory
  }

  private[this] def getFullFilename(name: String) = s"$name.txt"

  override def saveLiveData(name: String, contents: String): Try[Unit] = for {
    directory <- getDataDirectoryPath
    _ <- Try(Files.write(directory.resolve(getFullFilename(name)), contents.getBytes()))
  } yield ()

  import scala.collection.JavaConverters._

  override def loadLiveData(name: String): Try[List[String]] = for {
    directory <- getDataDirectoryPath
    lines <- Try(Files.lines(directory.resolve(getFullFilename(name))).iterator.asScala.toList)
  } yield lines
} 
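A minimal usage sketch for the service above; the directory and record name are made up. saveLiveData writes "<name>.txt" under the data directory and loadLiveData reads it back line by line.

import java.nio.file.Paths

import qamr.FileSystemAnnotationDataService

object AnnotationDataServiceSketch extends App {
  val dataService = new FileSystemAnnotationDataService(Paths.get("/tmp/qamr-annotations"))

  dataService.saveLiveData("hit-types", "assignment-1\nassignment-2")
  println(dataService.loadLiveData("hit-types")) // Success(List(assignment-1, assignment-2))
}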
Example 191
Source File: MetronomeITBase.scala    From metronome   with Apache License 2.0 5 votes vote down vote up
package dcos.metronome.integrationtest

import java.util.UUID

import com.mesosphere.utils.AkkaUnitTest
import com.mesosphere.utils.http.RestResultMatchers
import com.mesosphere.utils.mesos.MesosClusterTest
import com.typesafe.scalalogging.StrictLogging
import dcos.metronome.integrationtest.utils.{MetronomeFacade, MetronomeFramework}
import org.apache.mesos.v1.Protos.FrameworkID
import org.scalatest.Inside

import scala.concurrent.duration._

class MetronomeITBase
    extends AkkaUnitTest
    with MesosClusterTest
    with Inside
    with RestResultMatchers
    with StrictLogging {

  override lazy implicit val patienceConfig = PatienceConfig(180.seconds, interval = 1.second)

  def withFixture(frameworkId: Option[FrameworkID.Builder] = None)(fn: Fixture => Unit): Unit = {
    val f = new Fixture(frameworkId)
    try fn(f)
    finally {
      f.metronomeFramework.stop()
    }
  }

  class Fixture(existingFrameworkId: Option[FrameworkID.Builder] = None) extends StrictLogging {
    logger.info("Create Fixture with new Metronome...")

    val zkUrl = s"zk://${zkserver.connectUrl}/metronome_${UUID.randomUUID()}"
    val masterUrl = mesosFacade.url.getHost + ":" + mesosFacade.url.getPort

    val currentITName = MetronomeITBase.this.getClass.getSimpleName

    val metronomeFramework = MetronomeFramework.LocalMetronome(currentITName, masterUrl, zkUrl)

    logger.info("Starting metronome...")
    metronomeFramework.start().futureValue

    logger.info(s"Metronome started, reachable on: ${metronomeFramework.url}")
    lazy val metronome: MetronomeFacade = metronomeFramework.facade
  }

} 
Example 192
Source File: HttpExecutionTest.scala    From maze   with Apache License 2.0 5 votes vote down vote up
package fr.vsct.dt.maze.helpers

import com.typesafe.scalalogging.StrictLogging
import fr.vsct.dt.maze.core.Commands.expectThat
import fr.vsct.dt.maze.core.Predef._
import fr.vsct.dt.maze.core.{Predicate, Result}
import org.apache.http._
import org.apache.http.client.methods.{CloseableHttpResponse, HttpGet}
import org.apache.http.entity.{ContentType, StringEntity}
import org.apache.http.impl.client.CloseableHttpClient
import org.apache.http.message.{BasicHttpResponse, BasicStatusLine}
import org.apache.http.protocol.HttpContext
import org.scalatest.FlatSpec

import scala.beans.BeanProperty


class HttpExecutionTest extends FlatSpec {

  
  class MockHttpClient(val response: String) extends CloseableHttpClient with StrictLogging {
    var init = false

    override def doExecute(target: HttpHost, request: HttpRequest, context: HttpContext): CloseableHttpResponse = {
      if (!init) throw new IllegalStateException("Client is not initialized")
      logger.info("Doing actual http call")
      val r = if(request.getRequestLine.getUri == "http://some-url.com") {
        val t = new BasicCloseableHttpResponse(new BasicStatusLine(HttpVersion.HTTP_1_1, HttpStatus.SC_OK, "OK"))
        t.setEntity(new StringEntity(response, "UTF-8"))
        t
      } else {
        val t = new BasicCloseableHttpResponse(new BasicStatusLine(HttpVersion.HTTP_1_1, HttpStatus.SC_BAD_REQUEST, "KO"))
        t.setEntity(new StringEntity("""{"status": "ko"}""", ContentType.APPLICATION_JSON))
        t
      }
      r
    }

    @Deprecated
    override def getConnectionManager = null

    @Deprecated
    override def getParams = null

    override def close(): Unit = {}

    class BasicCloseableHttpResponse(statusLine: StatusLine) extends BasicHttpResponse(statusLine) with CloseableHttpResponse {
      override def close(): Unit = {}
    }

  }

  "a http check" should "not do an effective call until apply is effectively called" in {

    Http.client = new MockHttpClient("Youppy !")

    val requestOk = new HttpGet("http://some-url.com")

    val check1: Predicate = Http.execute(requestOk).status is 200
    val check2: Predicate = Http.execute(requestOk).response is "Youppy !"

    val check3 = check1 || check2

    val check4 = !check3

    Http.client.asInstanceOf[MockHttpClient].init = true

    assert(check1.get() == Result.success)
    assert(check2.get() == Result.success)
    assert(check3.get() == Result.success)
    assert(check4.get() == Result.failure(s"Expected ${check3.label} to be false"))
    expectThat(Http.get("http://some-error-url.com").status is 400)
    expectThat(Http.get("http://some-url.com").isOk)
    expectThat(!Http.get("http://some-error-url.com").isOk)
    expectThat(Http.get("http://some-error-url.com").responseAs(classOf[Stupid]) is Stupid(status = "ko"))

  }

}

case class Stupid(@BeanProperty status: String) 
Example 193
Source File: MergeHubDemo.scala    From fusion-data   with Apache License 2.0 5 votes vote down vote up
package example.akkastream.dynamichub

import akka.NotUsed
import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.{ MergeHub, RunnableGraph, Sink, Source }
import com.typesafe.scalalogging.StrictLogging

import scala.io.StdIn

object MergeHubDemo extends App with StrictLogging {
  implicit val system = ActorSystem()
  implicit val mat = ActorMaterializer()

  // A simple consumer that will print to the console for now
  val consumer = Sink.foreach[String](v => logger.info(s"consumer: $v"))

  // Attach a MergeHub Source to the consumer. This will materialize to a
  // corresponding Sink.
  val runnableGraph: RunnableGraph[Sink[String, NotUsed]] =
    MergeHub.source[String](perProducerBufferSize = 16).to(consumer)

  // By running/materializing the consumer we get back a Sink, and hence
  // now have access to feed elements into it. This Sink can be materialized
  // any number of times, and every element that enters the Sink will
  // be consumed by our consumer.
  val toConsumer: Sink[String, NotUsed] = runnableGraph.run()

  // Feeding two independent sources into the hub.
  Source.single("Hello!").runWith(toConsumer)
  Source.single("Hub!").runWith(toConsumer)

  StdIn.readLine()
  system.terminate()
} 
Example 194
Source File: SimplePublishSubscribe.scala    From fusion-data   with Apache License 2.0 5 votes vote down vote up
package example.akkastream.dynamichub

import akka.actor.ActorSystem
import akka.stream.{ ActorMaterializer, KillSwitches, UniqueKillSwitch }
import akka.stream.scaladsl.{ BroadcastHub, Flow, Keep, MergeHub, Sink, Source }
import com.typesafe.scalalogging.StrictLogging

import scala.io.StdIn
import scala.concurrent.duration._

object SimplePublishSubscribe extends App with StrictLogging {
  implicit val system = ActorSystem()
  implicit val mat = ActorMaterializer()
  import system.dispatcher

  val (sink, source) =
    MergeHub.source[String](perProducerBufferSize = 16).toMat(BroadcastHub.sink(bufferSize = 256))(Keep.both).run()

  source.runWith(Sink.ignore)

  val busFlow: Flow[String, String, UniqueKillSwitch] = Flow
    .fromSinkAndSource(sink, source)
    .joinMat(KillSwitches.singleBidi[String, String])(Keep.right)
    .backpressureTimeout(3.seconds)

  val switch: UniqueKillSwitch =
    Source.repeat("Hello world!").viaMat(busFlow)(Keep.right).to(Sink.foreach(v => logger.info(s"switch: $v"))).run()

  Thread.sleep(200)
  switch.shutdown()

  StdIn.readLine()
  system.terminate()
} 
Example 195
Source File: ConnectorSystem.scala    From fusion-data   with Apache License 2.0 5 votes vote down vote up
package mass.connector

import java.nio.file.Path

import akka.Done
import akka.actor.ExtendedActorSystem
import com.typesafe.scalalogging.StrictLogging
import fusion.common.extension.{ FusionExtension, FusionExtensionId }
import fusion.core.extension.FusionCore
import mass.core.Constants

import scala.concurrent.Future
import scala.util.{ Failure, Success }

final class ConnectorSystem private (override val classicSystem: ExtendedActorSystem)
    extends FusionExtension
    with StrictLogging {
  private var _parsers = Map.empty[String, ConnectorParser]
  private var _connectors = Map.empty[String, Connector]

  init()

  private def init(): Unit = {
    configuration.get[Seq[String]](s"${Constants.BASE_CONF}.connector.parsers").foreach { className =>
      classicSystem.dynamicAccess.createInstanceFor[ConnectorParser](className, Nil) match {
        case Success(parse) => registerConnectorParser(parse)
        case Failure(e)     => logger.error(s"Unknown ConnectorParser: $className", e)
      }
    }
    FusionCore(classicSystem).shutdowns.serviceUnbind("ConnectorSystem") { () =>
      Future {
        connectors.foreach { case (_, c) => c.close() }
        Done
      }(classicSystem.dispatcher)
    }
  }

  def name: String = classicSystem.name

  def getConnector(name: String): Option[Connector] = _connectors.get(name)

  def connectors: Map[String, Connector] = _connectors

  def registerConnector(c: Connector): Map[String, Connector] = {
    _connectors = _connectors.updated(c.name, c)
    _connectors
  }

  def parsers: Map[String, ConnectorParser] = _parsers

  def registerConnectorParser(parse: ConnectorParser): Map[String, ConnectorParser] = {
    _parsers = _parsers.updated(parse.`type`, parse)
    logger.info(s"注册Connector解析器:$parse,当前数量:${parsers.size}")
    parsers
  }

  def fromFile(path: Path): Option[Connector] = ???

  def fromXML(node: scala.xml.Node): Option[Connector] = {
    import mass.core.XmlUtils.XmlRich
    val maybeParser = parsers.get(node.attr("type"))
    maybeParser.map(cp => cp.parseFromXML(node))
  }
}

object ConnectorSystem extends FusionExtensionId[ConnectorSystem] {
  override def createExtension(system: ExtendedActorSystem): ConnectorSystem = new ConnectorSystem(system)
} 
Example 196
Source File: RdpSystem.scala    From fusion-data   with Apache License 2.0 5 votes vote down vote up
package mass.rdp

import akka.actor.ExtendedActorSystem
import akka.stream.Materializer
import com.typesafe.scalalogging.StrictLogging
import fusion.common.extension.{ FusionExtension, FusionExtensionId }
import mass.connector.ConnectorSystem
import mass.core.Constants
import mass.extension.MassCore
import mass.rdp.etl.graph.{ EtlGraphParserFactory, EtlStreamFactory }
import mass.rdp.module.RdpModule

import scala.util.{ Failure, Success }

trait RdpRefFactory {
  def settings: MassCore

  def connectorSystem: ConnectorSystem
}

private[rdp] class RdpSetup(val system: ExtendedActorSystem) extends StrictLogging {
  val massCore = MassCore(system)

  val extensions: Vector[RdpModule] =
    massCore.configuration
      .get[Seq[String]](s"${Constants.BASE_CONF}.rdp.extensions")
      .flatMap { className =>
        system.dynamicAccess.createInstanceFor[RdpModule](className, Nil) match {
          case Success(v) => Some(v)
          case Failure(e) =>
            logger.warn(s"初始化找到未知RdpExtension", e)
            None
        }
      }
      .toVector

  def initialStreamFactories(): Map[String, EtlStreamFactory] = {
    val list = extensions.flatMap(_.etlStreamBuilders) ++
      massCore.configuration.get[Seq[String]](s"${Constants.BASE_CONF}.rdp.stream-builders").flatMap { className =>
        system.dynamicAccess.createInstanceFor[EtlStreamFactory](className, Nil) match {
          case Success(v) => Some(v)
          case Failure(e) =>
            logger.warn(s"初始化找到未知EtlStreamBuilder", e)
            None
        }
      }
    list.map(v => v.`type` -> v).toMap
  }

  def initialGraphParserFactories(): Map[String, EtlGraphParserFactory] =
    extensions.flatMap(_.graphParserFactories).map(v => v.`type` -> v).toMap
}


final class RdpSystem private (override val classicSystem: ExtendedActorSystem)
    extends RdpRefFactory
    with FusionExtension
    with StrictLogging {
  override val connectorSystem: ConnectorSystem = ConnectorSystem(classicSystem)
  implicit val materializer: Materializer = Materializer.matFromSystem(classicSystem)

  private val setup = new RdpSetup(classicSystem)

  protected var _streamFactories: Map[String, EtlStreamFactory] = setup.initialStreamFactories()

  protected var _graphParerFactories: Map[String, EtlGraphParserFactory] = setup.initialGraphParserFactories()

  def streamFactories: Map[String, EtlStreamFactory] = _streamFactories

  def registerSourceBuilder(b: EtlStreamFactory): Unit = {
    logger.info(s"注册EtlSourceBuilder: $b")
    _streamFactories = _streamFactories.updated(b.`type`, b)
  }

  def graphParserFactories: Map[String, EtlGraphParserFactory] =
    _graphParerFactories

  def registerGraphParserFactories(b: EtlGraphParserFactory): Unit = {
    logger.info(s"注册EtlGraphParserFactor: $b")
    _graphParerFactories = _graphParerFactories.updated(b.`type`, b)
  }
  override def settings: MassCore = setup.massCore
  def name: String = classicSystem.name
}

object RdpSystem extends FusionExtensionId[RdpSystem] {
  override def createExtension(system: ExtendedActorSystem): RdpSystem = new RdpSystem(system)
} 
Example 197
Source File: EtlGraphImpl.scala    From fusion-data   with Apache License 2.0 5 votes vote down vote up
package mass.rdp.etl.graph

import akka.NotUsed
import akka.stream.scaladsl.{ Sink, Source }
import com.typesafe.scalalogging.StrictLogging
import javax.script.SimpleBindings
import mass.connector.Connector
import mass.connector.sql._
import mass.core.event.{ EventData, EventDataSimple }
import mass.core.script.ScriptManager
import mass.rdp.RdpSystem
import mass.rdp.etl.{ EtlResult, EtlWorkflowExecution, SqlEtlResult }

import scala.collection.immutable
import scala.concurrent.{ Future, Promise }
import scala.util.{ Failure, Success }

case class EtlGraphImpl(graphSetting: EtlGraphSetting) extends EtlGraph with StrictLogging {
  override def run(connectors: immutable.Seq[Connector], rdpSystem: RdpSystem): EtlWorkflowExecution = {
    implicit val ec = rdpSystem.materializer.system.dispatcher
    implicit val mat = rdpSystem.materializer

    def getConnector(name: String): Connector =
      connectors.find(_.name == name) orElse
      rdpSystem.connectorSystem.getConnector(name) getOrElse
      (throw new EtlGraphException(s"connector ref: $name 不存在"))

    val promise = Promise[EtlResult]()

    val source = dataSource(getConnector(graphSource.connector.ref), rdpSystem)
    val sink = dataSink(getConnector(graphSink.connector.ref), rdpSystem)

    graphFlows
      .foldLeft(source)((s, etlFlow) =>
        s.map { event =>
          val engine = ScriptManager.scriptJavascript
          val bindings = new SimpleBindings()
          bindings.put("event", event.asInstanceOf[EventDataSql])
          val data = engine.eval(etlFlow.script.content.get, bindings)

          // TODO Optionally notify an online monitoring system from here
          logger.debug(s"engine: $engine, event: $event, result data: $data")

          EventDataSimple(data)
        })
      .runWith(sink)
      .onComplete {
        case Success(result) => promise.success(SqlEtlResult(result))
        case Failure(e)      => promise.failure(e)
      }

    new EtlWorkflowExecution(promise, () => ())
  }

  private def dataSource(connector: Connector, rdpSystem: RdpSystem): Source[EventData, NotUsed] =
    rdpSystem.streamFactories.get(connector.`type`.toString) match {
      case Some(b) => b.buildSource(connector, graphSource)
      case _       => throw new EtlGraphException(s"未知Connector: $connector")
    }

  private def dataSink(connector: Connector, rdpSystem: RdpSystem): Sink[EventData, Future[JdbcSinkResult]] =
    rdpSystem.streamFactories.get(connector.`type`.toString) match {
      case Some(b) => b.buildSink(connector, graphSink)
      case _       => throw new EtlGraphException(s"未知Connector: $connector")
    }
} 
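The per-flow transformation above amounts to evaluating the flow's script with the current event bound under the name "event". The standalone sketch below reproduces that step with the JDK's javax.script API directly, bypassing the project's ScriptManager; it assumes a JavaScript engine (Nashorn on JDK 8-14, or GraalJS) is available, and a plain String stands in for EventDataSql.

import javax.script.{ ScriptEngineManager, SimpleBindings }

object ScriptStepSketch extends App {
  // Stand-in for ScriptManager.scriptJavascript: look the engine up directly.
  // getEngineByName returns null if no JavaScript engine is on the classpath
  // (e.g. JDK 15+ without GraalJS).
  val engine = new ScriptEngineManager().getEngineByName("javascript")

  // EtlGraphImpl binds the incoming event under the name "event"; here a plain
  // String takes the place of the project's EventDataSql.
  val bindings = new SimpleBindings()
  bindings.put("event", "alice")

  // The flow script can call Java methods on the bound object.
  val result = engine.eval("event.toUpperCase()", bindings)
  println(result) // ALICE
}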
Example 198
Source File: EtlGraphParser.scala    From fusion-data   with Apache License 2.0 5 votes vote down vote up
package mass.rdp.etl.graph

import com.typesafe.scalalogging.StrictLogging
import helloscala.common.util.StringUtils
import mass.core.XmlUtils

import scala.util.Try
import scala.xml.NodeSeq

trait EtlGraphParser {
  def parse(): Try[EtlGraphSetting]

  def validation(setting: EtlGraphSetting): Try[EtlGraphSetting] = Try {
    val sourceOut = setting.source.out
    val sinkName = setting.sink.name
    if (!(setting.flows.exists(_.name == sourceOut) || sinkName == sourceOut)) {
      throw new EtlGraphException("source.out未找到指定的flow或sink")
    }

    if (!(setting.flows.exists(_.outs.exists(_ == sinkName)) || sourceOut == sinkName)) {
      throw new EtlGraphException("graph不是闭合的")
    }

    // TODO additional graph validations

    setting
  }
}

trait EtlGraphParserFactory {
  def `type`: String
}

class EtlGraphXmlParserFactory extends EtlGraphParserFactory {
  override def `type`: String = "xml"

  def build(elem: NodeSeq): EtlGraphParser = new EtlGraphXmlParser(elem)

  class EtlGraphXmlParser(elem: NodeSeq) extends EtlGraphParser with StrictLogging {
    import mass.core.XmlUtils.XmlRich

    logger.trace(s"parse elem:\n$elem")

    def parse(): Try[EtlGraphSetting] = {
      val name = elem.attr("name")
      require(StringUtils.isNoneBlank(name), s"graph must define a name attribute: $elem")

      val source = parseSource(elem \ "source")
      val flows = (elem \ "flows" \ "flow").map(parseFlow).toVector
      val sink = parseSink(elem \ "sink")

      validation(EtlGraphSetting(name, source, flows, sink))
    }

    private def parseSource(node: NodeSeq): EtlSource = {
      val name = node.attr("name")
      val connector = parseConnector(node \ "connector")
      val script = parseScript(node \ "script")
      val out = XmlUtils.text(node \ "out")
      EtlSource(name, connector, script, out)
    }

    private def parseFlow(node: NodeSeq): EtlFlow = {
      val name = node.attr("name")
      val script = parseScript(node \ "script")
      val outs = (node \ "out").map(XmlUtils.text).toVector
      EtlFlow(name, script, outs)
    }

    private def parseSink(node: NodeSeq): EtlSink = {
      val name = node.attr("name")
      val connector = parseConnector(node \ "connector")
      val script = parseScript(node \ "script")
      EtlSink(name, connector, script)
    }

    @inline private def parseScript(node: NodeSeq): EtlScript = {
      logger.trace(s"parse script:\n$node")
      EtlScript(EtlScriptType.withName(node.attr("type")), node.getAttr("src"), node.getText)
    }

    @inline private def parseConnector(node: NodeSeq): EtlConnector =
      EtlConnector(node.attr("ref"))
  }
} 
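For reference, EtlGraphXmlParser expects a graph element with a source, a flows block and a sink, where the source's out names a flow and that flow's out names the sink (otherwise validation fails). The sketch below feeds such an XML literal through the factory; the connector refs and the script type values ("sql", "javascript") are illustrative assumptions about EtlScriptType and the configured connectors, not values taken from the project.

package mass.rdp.etl.graph

import scala.xml.Elem

object EtlGraphXmlParserSketch extends App {
  // A graph satisfying the validation rules: source.out -> flow-1, flow-1.out -> sink-1.
  val graphXml: Elem =
    <graph name="demo-graph">
      <source name="source-1">
        <connector ref="src-postgres"/>
        <script type="sql">SELECT id, name FROM users</script>
        <out>flow-1</out>
      </source>
      <flows>
        <flow name="flow-1">
          <script type="javascript">event</script>
          <out>sink-1</out>
        </flow>
      </flows>
      <sink name="sink-1">
        <connector ref="dst-mysql"/>
        <script type="sql">INSERT INTO users_copy (id, name) VALUES (?, ?)</script>
      </sink>
    </graph>

  val parser = new EtlGraphXmlParserFactory().build(graphXml)
  println(parser.parse()) // Success(EtlGraphSetting(...)) when the type names resolve
}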
Example 199
Source File: DefaultSchedulerJob.scala    From fusion-data   with Apache License 2.0 5 votes vote down vote up
package mass.job.component

import java.nio.file.Paths

import com.typesafe.scalalogging.StrictLogging
import fusion.inject.guice.GuiceApplication
import mass.core.job._
import mass.job.JobScheduler
import mass.message.job.SchedulerJobResult

import scala.concurrent.Future

class DefaultSchedulerJob extends SchedulerJob with StrictLogging {
  override def run(context: SchedulerContext): Future[JobResult] = {
    val jobScheduler = GuiceApplication(context.system).instance[JobScheduler]
    // TODO Use job blocking dispatcher
    val blockingDispatcher = jobScheduler.executionContext
    Future {
      context.jobItem.resources.get(JobConstants.Resources.ZIP_PATH) match {
        case Some(zipPath) => handleZip(zipPath, jobScheduler, context)
        case _             => handle(jobScheduler, context)
      }
    }(blockingDispatcher)
  }

  private def handleZip(zipPath: String, jobSystem: JobScheduler, ctx: SchedulerContext): SchedulerJobResult =
    JobRun.runOnZip(Paths.get(zipPath), ctx.key, ctx.jobItem, jobSystem.jobSettings)

  private def handle(jobSystem: JobScheduler, ctx: SchedulerContext): SchedulerJobResult =
    JobRun.run(ctx.jobItem, ctx.key, jobSystem.jobSettings)
} 
Example 200
Source File: JobRoute.scala    From fusion-data   with Apache License 2.0 5 votes vote down vote up
package mass.job.route.api.v1

import akka.actor.ActorSystem
import akka.http.scaladsl.server.Route
import com.typesafe.scalalogging.StrictLogging
import fusion.http.server.{ AbstractRoute, JacksonDirectives }
import fusion.json.jackson.http.JacksonSupport
import javax.inject.{ Inject, Singleton }
import mass.common.page.Page
import mass.extension.MassCore
import mass.job.service.job.JobService
import mass.message.job._

@Singleton
class JobRoute @Inject() (jobService: JobService, val jacksonSupport: JacksonSupport)(implicit system: ActorSystem)
    extends AbstractRoute
    with JacksonDirectives
    with StrictLogging {
  private val pagePDM =
    (Symbol("page").as[Int].?(Page.DEFAULT_PAGE), Symbol("size").as[Int].?(Page.DEFAULT_SIZE), Symbol("key").?)

  override def route: Route = pathPrefix("job") {
    createJobRoute ~
    updateJobRoute ~
    pageRoute ~
    itemByKeyRoute ~
    uploadJobZipRoute ~
    optionAllRoute ~
    uploadFileRoute
  }

  def createJobRoute: Route = pathPost("create") {
    entity(jacksonAs[JobCreateReq]) { req =>
      futureComplete(jobService.createJob(req))
    }
  }

  def updateJobRoute: Route = pathPost("update") {
    entity(jacksonAs[JobUpdateReq]) { req =>
      futureComplete(jobService.updateJob(req))
    }
  }

  def itemByKeyRoute: Route = pathGet("item" / Segment) { key =>
    futureComplete(jobService.findItemByKey(key))
  }

  def pageRoute: Route = pathGet("page") {
    parameters(pagePDM).as(JobPageReq.apply _) { req =>
      futureComplete(jobService.page(req))
    }
  }

  def uploadJobZipRoute: Route = pathPost("upload_job") {
    extractExecutionContext { implicit ec =>
      storeUploadedFile("job", createTempFileFunc(MassCore(system).tempDirectory)) {
        case (fileInfo, file) =>
          futureComplete(jobService.uploadJobOnZip(fileInfo, file.toPath))
      }
    }
  }

  def uploadFileRoute: Route = pathPost("upload_file") {
    extractExecutionContext { implicit ec =>
      uploadedFiles(createTempFileFunc(MassCore(system).tempDirectory)) { list =>
        futureComplete(jobService.uploadFiles(list))
      }
    }
  }

  def optionAllRoute: Route = pathGet("option_all") {
    futureComplete(jobService.listOption())
  }
}