java.sql.Connection Scala Examples

The following examples show how to use java.sql.Connection in Scala. Each example is taken from an open-source project; the project name and license are noted above each snippet.
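Before the project-specific examples, here is a minimal sketch of the plain JDBC pattern they all build on, assuming the H2 driver is on the classpath; the URL, table, and column names are placeholders:

import java.sql.{Connection, DriverManager}

object ConnectionBasics {
  def main(args: Array[String]): Unit = {
    // obtain a connection from the driver manager (the URL is a placeholder)
    val connection: Connection = DriverManager.getConnection("jdbc:h2:mem:example;DB_CLOSE_DELAY=-1")
    try {
      connection.createStatement().execute(
        "CREATE TABLE IF NOT EXISTS items (id INT PRIMARY KEY, name VARCHAR)")
      // bind parameters through a PreparedStatement rather than string concatenation
      val insert = connection.prepareStatement("INSERT INTO items (id, name) VALUES (?, ?)")
      insert.setInt(1, 1)
      insert.setString(2, "example")
      insert.executeUpdate()
      insert.close()
    } finally {
      // always release the connection when done
      connection.close()
    }
  }
}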
Example 1
Source File: Queries.scala    From daml   with Apache License 2.0
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.on.sql.queries

import java.io.InputStream
import java.sql.{Blob, Connection, PreparedStatement}

import anorm.{
  BatchSql,
  Column,
  MetaDataItem,
  NamedParameter,
  RowParser,
  SqlMappingError,
  SqlParser,
  SqlRequestError,
  ToStatement
}
import com.google.protobuf.ByteString

trait Queries extends ReadQueries with WriteQueries

object Queries {
  val TablePrefix = "ledger"
  val LogTable = s"${TablePrefix}_log"
  val MetaTable = s"${TablePrefix}_meta"
  val StateTable = s"${TablePrefix}_state"

  // By explicitly writing a value to a "table_key" column, we ensure we only ever have one row in
  // the meta table. An attempt to write a second row will result in a key conflict.
  private[queries] val MetaTableKey = 0

  def executeBatchSql(
      query: String,
      params: Iterable[Seq[NamedParameter]],
  )(implicit connection: Connection): Unit = {
    if (params.nonEmpty)
      BatchSql(query, params.head, params.drop(1).toArray: _*).execute()
    ()
  }

  implicit def byteStringToStatement: ToStatement[ByteString] = new ToStatement[ByteString] {
    override def set(s: PreparedStatement, index: Int, v: ByteString): Unit =
      s.setBinaryStream(index, v.newInput(), v.size())
  }

  implicit def columnToByteString: Column[ByteString] =
    Column.nonNull { (value: Any, meta: MetaDataItem) =>
      value match {
        case blob: Blob => Right(ByteString.readFrom(blob.getBinaryStream))
        case byteArray: Array[Byte] => Right(ByteString.copyFrom(byteArray))
        case inputStream: InputStream => Right(ByteString.readFrom(inputStream))
        case _ =>
          Left[SqlRequestError, ByteString](
            SqlMappingError(s"Cannot convert value of column ${meta.column} to ByteString"))
      }
    }

  def getBytes(columnName: String): RowParser[ByteString] =
    SqlParser.get(columnName)(columnToByteString)

} 
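A hypothetical caller of the helpers above might look like this; the column list, the connection setup, and the entry data are assumptions for illustration and not part of the daml sources:

import java.sql.{Connection, DriverManager}

import anorm.NamedParameter
import com.daml.ledger.on.sql.queries.Queries

object QueriesUsageSketch {
  def appendToLog(entries: Seq[(Array[Byte], Array[Byte])]): Unit = {
    // a real application would borrow the connection from a pool instead
    implicit val connection: Connection = DriverManager.getConnection("jdbc:h2:mem:ledger")
    try {
      val params = entries.map { case (entryId, envelope) =>
        Seq[NamedParameter]("entry_id" -> entryId, "envelope" -> envelope)
      }
      // executeBatchSql is a no-op when params is empty
      Queries.executeBatchSql(
        s"INSERT INTO ${Queries.LogTable} (entry_id, envelope) VALUES ({entry_id}, {envelope})",
        params)
    } finally connection.close()
  }
}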
Example 2
Source File: TagInputAssociation.scala    From smui   with Apache License 2.0
package models

import java.sql.Connection
import java.time.LocalDateTime

import anorm.SqlParser.get
import anorm._

case class TagInputAssociation(tagId: InputTagId,
                               searchInputId: SearchInputId,
                               lastUpdate: LocalDateTime = LocalDateTime.now()) {

  import TagInputAssociation._

  def toNamedParameters: Seq[NamedParameter] = Seq(
    TAG_ID -> tagId,
    INPUT_ID -> searchInputId,
    LAST_UPDATE -> lastUpdate
  )

}

object TagInputAssociation {

  val TABLE_NAME = "tag_2_input"

  val TAG_ID = "tag_id"
  val INPUT_ID = "input_id"
  val LAST_UPDATE = "last_update"

  def insert(associations: TagInputAssociation*)(implicit connection: Connection): Unit = {
    if (associations.nonEmpty) {
      BatchSql(s"insert into $TABLE_NAME ($TAG_ID, $INPUT_ID, $LAST_UPDATE) " +
        s"values ({$TAG_ID}, {$INPUT_ID}, {$LAST_UPDATE})",
        associations.head.toNamedParameters,
        associations.tail.map(_.toNamedParameters): _*
      ).execute()
    }
  }

  
  def updateTagsForSearchInput(searchInputId: SearchInputId, tagIds: Seq[InputTagId])(implicit connection: Connection): Unit = {
    deleteBySearchInputId(searchInputId)
    insert(tagIds.map(tagId => TagInputAssociation(tagId, searchInputId)): _*)
  }

  def loadTagsBySearchInputId(id: SearchInputId)(implicit connection: Connection): Seq[InputTag] = {
    SQL(s"select * from $TABLE_NAME a, ${InputTag.TABLE_NAME} t where a.$INPUT_ID = {inputId} " +
      s"and a.$TAG_ID = t.${InputTag.ID} order by t.${InputTag.PROPERTY} asc, t.${InputTag.VALUE} asc").
      on("inputId" -> id).as(InputTag.sqlParser.*)
  }

  def loadTagsBySearchInputIds(ids: Seq[SearchInputId])(implicit connection: Connection): Map[SearchInputId, Seq[InputTag]] = {
    ids.grouped(100).toSeq.flatMap { idGroup =>
      SQL(s"select * from $TABLE_NAME a, ${InputTag.TABLE_NAME} t where a.$INPUT_ID in ({inputIds}) " +
        s"and a.$TAG_ID = t.${InputTag.ID} order by t.${InputTag.PROPERTY} asc, t.${InputTag.VALUE} asc").
        on("inputIds" -> idGroup).as((InputTag.sqlParser ~ get[SearchInputId](s"$TABLE_NAME.$INPUT_ID")).*).
        map { case tag ~ inputId =>
          inputId -> tag
        }
    }.groupBy(_._1).mapValues(_.map(_._2))
  }

  def deleteBySearchInputId(id: SearchInputId)(implicit connection: Connection): Int = {
    SQL"delete from #$TABLE_NAME where #$INPUT_ID = $id".executeUpdate()
  }

} 
Example 3
Source File: GenericConnectionPool.scala    From airframe   with Apache License 2.0
package wvlet.airframe.jdbc
import java.sql.Connection

import com.zaxxer.hikari.{HikariConfig, HikariDataSource}


class GenericConnectionPool(val config: DbConfig) extends ConnectionPool {

  protected val dataSource: HikariDataSource = {
    val connectionPoolConfig = new HikariConfig

    // Set default JDBC parameters
    connectionPoolConfig.setMaximumPoolSize(config.connectionPool.maxPoolSize) // HikariCP default = 10
    connectionPoolConfig.setAutoCommit(config.connectionPool.autoCommit)       // Enable auto-commit

    connectionPoolConfig.setDriverClassName(config.jdbcDriverName)
    config.user.foreach(u => connectionPoolConfig.setUsername(u))
    config.password.foreach(p => connectionPoolConfig.setPassword(p))

    config.`type` match {
      case "postgresql" =>
        if (config.postgres.useSSL) {
          connectionPoolConfig.addDataSourceProperty("ssl", "true")
          connectionPoolConfig.addDataSourceProperty("sslfactory", config.postgres.sslFactory)
        }
      case _ => // no database-specific data source properties required
    }

    if (config.host.isEmpty) {
      throw new IllegalArgumentException(s"missing jdbc host: ${config}")
    }

    connectionPoolConfig.setJdbcUrl(config.jdbcUrl)

    info(s"jdbc URL: ${connectionPoolConfig.getJdbcUrl}")
    new HikariDataSource(config.connectionPool.hikariConfig(connectionPoolConfig))
  }

  override def withConnection[U](body: Connection => U): U = {
    val conn = dataSource.getConnection
    try {
      body(conn)
    } finally {
      // Return the connection to the pool
      conn.close
    }
  }

  override def stop: Unit = {
    info(s"Closing the connection pool for ${config.jdbcUrl}")
    dataSource.close()
  }
} 
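A hedged usage sketch of the pool above; constructing the DbConfig is omitted because its exact fields depend on the airframe-jdbc version, so the pool parameter is simply assumed to be already configured:

import java.sql.ResultSet

import wvlet.airframe.jdbc.ConnectionPool

object ConnectionPoolUsageSketch {
  // `pool` is assumed to be an already-built ConnectionPool (for example a GenericConnectionPool)
  def countRows(pool: ConnectionPool, table: String): Long = {
    pool.withConnection { conn =>
      val rs: ResultSet = conn.createStatement().executeQuery(s"SELECT COUNT(*) FROM $table")
      rs.next()
      rs.getLong(1)
    }
  }
}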
Example 4
Source File: DNSstat.scala    From jdbcsink   with Apache License 2.0
import org.apache.spark.sql.SparkSession
import java.util.Properties
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.{from_json,window}
import java.sql.{Connection,Statement,DriverManager}
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.Row

class JDBCSink() extends ForeachWriter[Row] {
  val driver = "com.mysql.jdbc.Driver"
  var connection: Connection = _
  var statement: Statement = _

  def open(partitionId: Long, version: Long): Boolean = {
    Class.forName(driver)
    connection = DriverManager.getConnection("jdbc:mysql://10.88.1.102:3306/aptwebservice", "root", "mysqladmin")
    statement = connection.createStatement
    true
  }

  def process(value: Row): Unit = {
    statement.executeUpdate("replace into DNSStat(ip,domain,time,count) values("
      + "'" + value.getString(0) + "'" + "," // ip
      + "'" + value.getString(1) + "'" + "," // domain
      + "'" + value.getTimestamp(2) + "'" + "," // time
      + value.getLong(3) // count
      + ")")
  }

  def close(errorOrNull: Throwable): Unit = {
    connection.close()
  }
}

object DNSstatJob{

val schema: StructType = StructType(
        Seq(StructField("Vendor", StringType,true),
         StructField("Id", IntegerType,true),
         StructField("Time", LongType,true),
         StructField("Conn", StructType(Seq(
                                        StructField("Proto", IntegerType, true), 
                                        StructField("Sport", IntegerType, true), 
                                        StructField("Dport", IntegerType, true), 
                                        StructField("Sip", StringType, true), 
                                        StructField("Dip", StringType, true)
                                        )), true),
        StructField("Dns", StructType(Seq(
                                        StructField("Domain", StringType, true), 
                                        StructField("IpCount", IntegerType, true), 
                                        StructField("Ip", StringType, true) 
                                        )), true)))

    def main(args: Array[String]) {
    val spark=SparkSession
          .builder
          .appName("DNSJob")
          .config("spark.some.config.option", "some-value")
          .getOrCreate()
    import spark.implicits._
    val connectionProperties = new Properties()
    connectionProperties.put("user", "root")
    connectionProperties.put("password", "mysqladmin")
    val bruteForceTab = spark.read
                .jdbc("jdbc:mysql://10.88.1.102:3306/aptwebservice", "DNSTab",connectionProperties)
    bruteForceTab.registerTempTable("DNSTab")
    val lines = spark
          .readStream
          .format("kafka")
          .option("kafka.bootstrap.servers", "10.94.1.110:9092")
          .option("subscribe","xdr")
          //.option("startingOffsets","earliest")
          .option("startingOffsets","latest")
          .load()
          .select(from_json($"value".cast(StringType),schema).as("jsonData"))
    lines.registerTempTable("xdr")
    val filterDNS = spark.sql("select CAST(from_unixtime(xdr.jsonData.Time DIV 1000000) as timestamp) as time,xdr.jsonData.Conn.Sip as sip, xdr.jsonData.Dns.Domain from xdr inner join DNSTab on xdr.jsonData.Dns.domain = DNSTab.domain")
    
    val windowedCounts = filterDNS
                        .withWatermark("time","5 minutes")
                        .groupBy(window($"time", "1 minutes", "1 minutes"),$"sip",$"domain")
                        .count()
                        .select($"sip",$"domain",$"window.start",$"count")

    val writer = new JDBCSink()
    val query = windowedCounts
      .writeStream
      .foreach(writer)
      .outputMode("update")
      .option("checkpointLocation", "/checkpoint/")
      .start()
    query.awaitTermination()
  }
}
Example 5
Source File: CommonQueries.scala    From daml   with Apache License 2.0
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.on.sql.queries

import java.sql.Connection

import anorm.SqlParser._
import anorm._
import com.daml.ledger.on.sql.Index
import com.daml.ledger.on.sql.queries.Queries._
import com.daml.ledger.participant.state.kvutils.KVOffset
import com.daml.ledger.participant.state.kvutils.api.LedgerRecord
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}

import scala.collection.{breakOut, immutable}
import scala.util.Try

trait CommonQueries extends Queries {
  protected implicit val connection: Connection

  override final def selectLatestLogEntryId(): Try[Option[Index]] = Try {
    SQL"SELECT MAX(sequence_no) max_sequence_no FROM #$LogTable"
      .as(get[Option[Long]]("max_sequence_no").singleOpt)
      .flatten
  }

  override final def selectFromLog(
      startExclusive: Index,
      endInclusive: Index,
  ): Try[immutable.Seq[(Index, LedgerRecord)]] = Try {
    SQL"SELECT sequence_no, entry_id, envelope FROM #$LogTable WHERE sequence_no > $startExclusive AND sequence_no <= $endInclusive ORDER BY sequence_no"
      .as((long("sequence_no") ~ getBytes("entry_id") ~ getBytes("envelope")).map {
        case index ~ entryId ~ envelope =>
          index -> LedgerRecord(KVOffset.fromLong(index), entryId, envelope)
      }.*)
  }

  override final def selectStateValuesByKeys(keys: Seq[Key]): Try[immutable.Seq[Option[Value]]] =
    Try {
      val results =
        SQL"SELECT key, value FROM #$StateTable WHERE key IN ($keys)"
          .fold(Map.newBuilder[Key, Value], ColumnAliaser.empty) { (builder, row) =>
            builder += row("key") -> row("value")
          }
          .fold(exceptions => throw exceptions.head, _.result())
      keys.map(results.get)(breakOut)
    }

  override final def updateState(stateUpdates: Seq[(Key, Value)]): Try[Unit] = Try {
    executeBatchSql(updateStateQuery, stateUpdates.map {
      case (key, value) =>
        Seq[NamedParameter]("key" -> key, "value" -> value)
    })
  }

  protected val updateStateQuery: String
} 
Example 6
Source File: SqliteQueries.scala    From daml   with Apache License 2.0
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.on.sql.queries

import java.sql.Connection

import anorm.SqlParser._
import anorm._
import com.daml.ledger.on.sql.Index
import com.daml.ledger.on.sql.queries.Queries._
import com.daml.ledger.participant.state.v1.LedgerId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}

import scala.util.Try

final class SqliteQueries(override protected implicit val connection: Connection)
    extends Queries
    with CommonQueries {
  override def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): Try[LedgerId] = Try {
    SQL"INSERT INTO #$MetaTable (table_key, ledger_id) VALUES ($MetaTableKey, $providedLedgerId) ON CONFLICT DO NOTHING"
      .executeInsert()
    SQL"SELECT ledger_id FROM #$MetaTable WHERE table_key = $MetaTableKey"
      .as(str("ledger_id").single)
  }

  override def insertRecordIntoLog(key: Key, value: Value): Try[Index] =
    Try {
      SQL"INSERT INTO #$LogTable (entry_id, envelope) VALUES ($key, $value)"
        .executeInsert()
      ()
    }.flatMap(_ => lastInsertId())

  override protected val updateStateQuery: String =
    s"INSERT INTO $StateTable VALUES ({key}, {value}) ON CONFLICT(key) DO UPDATE SET value = {value}"

  private def lastInsertId(): Try[Index] = Try {
    SQL"SELECT LAST_INSERT_ROWID() AS row_id"
      .as(long("row_id").single)
  }

  override final def truncate(): Try[Unit] = Try {
    SQL"delete from #$StateTable".executeUpdate()
    SQL"delete from #$LogTable".executeUpdate()
    SQL"delete from #$MetaTable".executeUpdate()
    ()
  }
}

object SqliteQueries {
  def apply(connection: Connection): Queries = {
    implicit val conn: Connection = connection
    new SqliteQueries
  }
} 
Example 7
Source File: PostgresqlQueries.scala    From daml   with Apache License 2.0
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.on.sql.queries

import java.sql.Connection

import anorm.SqlParser._
import anorm._
import com.daml.ledger.on.sql.Index
import com.daml.ledger.on.sql.queries.Queries._
import com.daml.ledger.participant.state.v1.LedgerId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}

import scala.util.Try

final class PostgresqlQueries(override protected implicit val connection: Connection)
    extends Queries
    with CommonQueries {
  override def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): Try[LedgerId] = Try {
    SQL"INSERT INTO #$MetaTable (table_key, ledger_id) VALUES ($MetaTableKey, $providedLedgerId) ON CONFLICT DO NOTHING"
      .executeInsert()
    SQL"SELECT ledger_id FROM #$MetaTable WHERE table_key = $MetaTableKey"
      .as(str("ledger_id").single)
  }

  override def insertRecordIntoLog(key: Key, value: Value): Try[Index] = Try {
    SQL"INSERT INTO #$LogTable (entry_id, envelope) VALUES ($key, $value) RETURNING sequence_no"
      .as(long("sequence_no").single)
  }

  override protected val updateStateQuery: String =
    s"INSERT INTO $StateTable VALUES ({key}, {value}) ON CONFLICT(key) DO UPDATE SET value = {value}"

  override final def truncate(): Try[Unit] = Try {
    SQL"truncate #$StateTable".executeUpdate()
    SQL"truncate #$LogTable".executeUpdate()
    SQL"truncate #$MetaTable".executeUpdate()
    ()
  }
}

object PostgresqlQueries {
  def apply(connection: Connection): Queries = {
    implicit val conn: Connection = connection
    new PostgresqlQueries
  }
} 
Example 8
Source File: V10_1__Populate_Event_Data.scala    From daml   with Apache License 2.0
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// Note: package name must correspond exactly to the flyway 'locations' setting, which defaults to
// 'db.migration.postgres' for postgres migrations
package db.migration.postgres

import java.sql.Connection

import anorm.{BatchSql, NamedParameter}
import com.daml.lf.transaction.{Transaction => Tx}
import com.daml.lf.transaction.Node.NodeCreate
import com.daml.ledger.EventId
import com.daml.lf.data.Ref
import com.daml.platform.store.Conversions._
import db.migration.translation.TransactionSerializer
import org.flywaydb.core.api.migration.{BaseJavaMigration, Context}

class V10_1__Populate_Event_Data extends BaseJavaMigration {

  val SELECT_TRANSACTIONS =
    "select distinct le.transaction_id, le.transaction from contracts c join ledger_entries le  on c.transaction_id = le.transaction_id"

  def loadTransactions(conn: Connection) = {
    val statement = conn.createStatement()
    val rows = statement.executeQuery(SELECT_TRANSACTIONS)

    new Iterator[(Ref.LedgerString, Tx.Transaction)] {
      var hasNext: Boolean = rows.next()

      def next(): (Ref.LedgerString, Tx.Transaction) = {
        val transactionId = Ref.LedgerString.assertFromString(rows.getString("transaction_id"))
        val transaction = TransactionSerializer
          .deserializeTransaction(transactionId, rows.getBinaryStream("transaction"))
          .getOrElse(sys.error(s"failed to deserialize transaction $transactionId"))

        hasNext = rows.next()
        if (!hasNext) {
          statement.close()
        }

        transactionId -> transaction
      }
    }
  }

  private val batchSize = 10 * 1000

  override def migrate(context: Context): Unit = {
    val conn = context.getConnection

    val txs = loadTransactions(conn)
    val data = txs.flatMap {
      case (txId, tx) =>
        tx.nodes.collect {
          case (nodeId, NodeCreate(cid, _, _, signatories, stakeholders, _)) =>
            (cid, EventId(txId, nodeId), signatories, stakeholders -- signatories)
        }
    }

    data.grouped(batchSize).foreach { batch =>
      val updateContractsParams = batch.map {
        case (cid, eventId, _, _) =>
          Seq[NamedParameter]("event_id" -> eventId, "contract_id" -> cid.coid)
      }
      BatchSql(
        "UPDATE contracts SET create_event_id = {event_id} where id = {contract_id}",
        updateContractsParams.head,
        updateContractsParams.tail: _*
      ).execute()(conn)

      val signatories = batch.flatMap {
        case (cid, _, signatories, _) =>
          signatories.map(signatory =>
            Seq[NamedParameter]("contract_id" -> cid.coid, "party" -> signatory))
      }
      BatchSql(
        "INSERT INTO contract_signatories VALUES ({contract_id}, {party})",
        signatories.head,
        signatories.tail: _*
      ).execute()(conn)

      val observers = batch.flatMap {
        case (cid, _, _, observers) =>
          observers.map(observer =>
            Seq[NamedParameter]("contract_id" -> cid.coid, "party" -> observer))
      }
      if (observers.nonEmpty) {
        BatchSql(
          "INSERT INTO contract_observers VALUES ({contract_id}, {party})",
          observers.head,
          observers.tail: _*
        ).execute()(conn)
      }
      ()
    }
    ()
  }
} 
Example 9
Source File: V3__Recompute_Key_Hash.scala    From daml   with Apache License 2.0
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// Note: package name must correspond exactly to the flyway 'locations' setting, which defaults to
// 'db.migration.postgres' for postgres migrations
package db.migration.postgres

import java.sql.{Connection, ResultSet}

import anorm.{BatchSql, NamedParameter}
import com.daml.lf.data.Ref
import com.daml.lf.transaction.Node.GlobalKey
import com.daml.lf.value.Value.ContractId
import com.daml.platform.store.serialization.{KeyHasher, ValueSerializer}
import org.flywaydb.core.api.migration.{BaseJavaMigration, Context}

class V3__Recompute_Key_Hash extends BaseJavaMigration {

  // the number of contracts processed in a batch.
  private val batchSize = 10 * 1000

  def migrate(context: Context): Unit = {
    implicit val conn: Connection = context.getConnection
    updateKeyHashed(loadContractKeys)
  }

  private def loadContractKeys(
      implicit connection: Connection
  ): Iterator[(ContractId, GlobalKey)] = {

    val SQL_SELECT_CONTRACT_KEYS =
      """
      |SELECT
      |  contracts.id as contract_id,
      |  contracts.package_id as package_id,
      |  contracts.name as template_name,
      |  contracts.key as contract_key
      |FROM
      |  contracts
      |WHERE
      |  contracts.key is not null
    """.stripMargin

    val rows: ResultSet = connection.createStatement().executeQuery(SQL_SELECT_CONTRACT_KEYS)

    new Iterator[(ContractId, GlobalKey)] {

      var hasNext: Boolean = rows.next()

      def next(): (ContractId, GlobalKey) = {
        val contractId = ContractId.assertFromString(rows.getString("contract_id"))
        val templateId = Ref.Identifier(
          packageId = Ref.PackageId.assertFromString(rows.getString("package_id")),
          qualifiedName = Ref.QualifiedName.assertFromString(rows.getString("template_name"))
        )
        val key = ValueSerializer
          .deserializeValue(rows.getBinaryStream("contract_key"))
          .assertNoCid(coid => s"Found contract ID $coid in contract key")

        hasNext = rows.next()
        contractId -> GlobalKey(templateId, key.value)
      }
    }

  }

  private def updateKeyHashed(contractKeys: Iterator[(ContractId, GlobalKey)])(
      implicit conn: Connection): Unit = {

    val SQL_UPDATE_CONTRACT_KEYS_HASH =
      """
        |UPDATE
        |  contract_keys
        |SET
        |  value_hash = {valueHash}
        |WHERE
        |  contract_id = {contractId}
      """.stripMargin

    val statements = contractKeys.map {
      case (cid, key) =>
        Seq[NamedParameter]("contractId" -> cid.coid, "valueHash" -> KeyHasher.hashKeyString(key))
    }

    statements.toStream.grouped(batchSize).foreach { batch =>
      BatchSql(
        SQL_UPDATE_CONTRACT_KEYS_HASH,
        batch.head,
        batch.tail: _*
      ).execute()
    }
  }

} 
Example 10
Source File: V5_1__Populate_Event_Data.scala    From daml   with Apache License 2.0
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// Note: package name must correspond exactly to the flyway 'locations' setting, which defaults to
// 'db.migration.h2database' for h2database migrations
package db.migration.h2database

import java.sql.Connection

import anorm.{BatchSql, NamedParameter}
import com.daml.lf.transaction.{Transaction => Tx}
import com.daml.lf.transaction.Node.NodeCreate
import com.daml.ledger.EventId
import com.daml.lf.data.Ref.LedgerString
import com.daml.platform.store.Conversions._
import db.migration.translation.TransactionSerializer
import org.flywaydb.core.api.migration.{BaseJavaMigration, Context}

class V5_1__Populate_Event_Data extends BaseJavaMigration {

  val SELECT_TRANSACTIONS =
    "select distinct le.transaction_id, le.transaction from contracts c join ledger_entries le  on c.transaction_id = le.transaction_id"

  def loadTransactions(conn: Connection) = {
    val statement = conn.createStatement()
    val rows = statement.executeQuery(SELECT_TRANSACTIONS)

    new Iterator[(LedgerString, Tx.Transaction)] {
      var hasNext: Boolean = rows.next()

      def next(): (LedgerString, Tx.Transaction) = {
        val transactionId = LedgerString.assertFromString(rows.getString("transaction_id"))
        val transaction = TransactionSerializer
          .deserializeTransaction(transactionId, rows.getBinaryStream("transaction"))
          .getOrElse(sys.error(s"failed to deserialize transaction $transactionId"))

        hasNext = rows.next()
        if (!hasNext) {
          statement.close()
        }

        transactionId -> transaction
      }
    }
  }

  private val batchSize = 10 * 1000

  override def migrate(context: Context): Unit = {
    val conn = context.getConnection

    val txs = loadTransactions(conn)
    val data = txs.flatMap {
      case (txId, tx) =>
        tx.nodes.collect {
          case (nodeId, NodeCreate(cid, _, _, signatories, stakeholders, _)) =>
            (cid, EventId(txId, nodeId), signatories, stakeholders -- signatories)
        }
    }

    data.grouped(batchSize).foreach { batch =>
      val updateContractsParams = batch.map {
        case (cid, eventId, _, _) =>
          Seq[NamedParameter]("event_id" -> eventId, "contract_id" -> cid.coid)
      }
      BatchSql(
        "UPDATE contracts SET create_event_id = {event_id} where id = {contract_id}",
        updateContractsParams.head,
        updateContractsParams.tail: _*
      ).execute()(conn)

      val signatories = batch.flatMap {
        case (cid, _, signatories, _) =>
          signatories.map(signatory =>
            Seq[NamedParameter]("contract_id" -> cid.coid, "party" -> signatory))
      }
      BatchSql(
        "INSERT INTO contract_signatories VALUES ({contract_id}, {party})",
        signatories.head,
        signatories.tail: _*
      ).execute()(conn)

      val observers = batch.flatMap {
        case (cid, _, _, observers) =>
          observers.map(observer =>
            Seq[NamedParameter]("contract_id" -> cid.coid, "party" -> observer))
      }
      if (observers.nonEmpty) {
        BatchSql(
          "INSERT INTO contract_observers VALUES ({contract_id}, {party})",
          observers.head,
          observers.tail: _*
        ).execute()(conn)
      }
      ()
    }
    ()
  }
} 
Example 11
Source File: JDBCSink.scala    From BigData-News   with Apache License 2.0
package com.vita.spark

import java.sql.{Connection, ResultSet, SQLException, Statement}

import org.apache.log4j.{LogManager, Logger}
import org.apache.spark.sql.{ForeachWriter, Row}

/**
  * Writes rows from Structured Streaming into MySQL.
  */
class JDBCSink(url: String, username: String, password: String) extends ForeachWriter[Row] {

  var statement: Statement = _
  var resultSet: ResultSet = _
  var connection: Connection = _

  override def open(partitionId: Long, version: Long): Boolean = {
    connection = new MySqlPool(url, username, password).getJdbcConn()
    statement = connection.createStatement();
    print("open")
    return true
  }

  override def process(value: Row): Unit = {
    println("process step one")
    val titleName = value.getAs[String]("titleName").replaceAll("[\\[\\]]", "")
    val count = value.getAs[Long]("count")

    val querySql = "select 1 from webCount where titleName = '" + titleName + "'"
    val insertSql = "insert into webCount(titleName,count) values('" + titleName + "' , '" + count + "')"
    val updateSql = "update webCount set count = " + count + " where titleName = '" + titleName + "'"
    println("process step two")
    try {
      // check whether a row for this title already exists
      val resultSet = statement.executeQuery(querySql)
      if (resultSet.next()) {
        println("updateSql")
        statement.executeUpdate(updateSql)
      } else {
        println("insertSql")
        statement.execute(insertSql)
      }
    } catch {
      // more specific exception types must be matched before their supertypes,
      // otherwise the later cases are unreachable
      case ex: SQLException =>
        println("SQLException")
      case ex: RuntimeException =>
        println("RuntimeException")
      case ex: Exception =>
        println("Exception")
      case ex: Throwable =>
        println("Throwable")
    }
  }

  override def close(errorOrNull: Throwable): Unit = {
    // release resources only if they were actually created
    if (statement != null) {
      statement.close()
    }
    if (connection != null) {
      connection.close()
    }
  }
} 
Example 12
Source File: MySqlPool.scala    From BigData-News   with Apache License 2.0
package com.vita.spark

import java.sql.{Connection, DriverManager}
import java.util

import org.apache.log4j.{LogManager, Logger}


/**
  * A simple MySQL connection pool that hands out connections on demand.
  */
class MySqlPool(url: String, user: String, pwd: String) extends Serializable {
  // maximum number of connections the pool may create
  private val max = 3

  // number of connections created per refill
  private val connectionNum = 1

  // number of connections created so far
  private var conNum = 0

  private val pool = new util.LinkedList[Connection]() // the pool itself

  val LOGGER :Logger = LogManager.getLogger("vita")

  // obtain a connection
  def getJdbcConn(): Connection = {
    LOGGER.info("getJdbcConn")
    // synchronized block; AnyRef is the base class of all reference types, AnyVal of all value types
    AnyRef.synchronized({
      if (pool.isEmpty) {
        // load the driver and refill the pool
        preGetConn()
        for (i <- 1 to connectionNum) {
          val conn = DriverManager.getConnection(url, user, pwd)
          pool.push(conn)
          conNum += 1
        }
      }
      pool.poll()
    })
  }

  // return a connection to the pool
  def releaseConn(conn: Connection): Unit = {
    pool.push(conn)
  }

  // load the driver
  private def preGetConn(): Unit = {
    // throttle creation of new connections
    if (conNum < max && !pool.isEmpty) {
      LOGGER.info("Jdbc Pool has no connection now, please wait a moments!")
      Thread.sleep(2000)
      preGetConn()
    } else {
      Class.forName("com.mysql.jdbc.Driver")
    }
  }
} 
Example 13
Source File: PostgresDialect.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Types}

import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._


private object PostgresDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.REAL) {
      Some(FloatType)
    } else if (sqlType == Types.SMALLINT) {
      Some(ShortType)
    } else if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
      Some(BinaryType)
    } else if (sqlType == Types.OTHER) {
      Some(StringType)
    } else if (sqlType == Types.ARRAY) {
      val scale = md.build.getLong("scale").toInt
      // postgres array type names start with underscore
      toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_))
    } else None
  }

  private def toCatalystType(
      typeName: String,
      precision: Int,
      scale: Int): Option[DataType] = typeName match {
    case "bool" => Some(BooleanType)
    case "bit" => Some(BinaryType)
    case "int2" => Some(ShortType)
    case "int4" => Some(IntegerType)
    case "int8" | "oid" => Some(LongType)
    case "float4" => Some(FloatType)
    case "money" | "float8" => Some(DoubleType)
    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
      Some(StringType)
    case "bytea" => Some(BinaryType)
    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
    case "date" => Some(DateType)
    case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale))
    case _ => None
  }

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.CHAR))
    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
    case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
    case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case t: DecimalType => Some(
      JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
      getJDBCType(et).map(_.databaseTypeDefinition)
        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
    case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt");
    case _ => None
  }

  override def getTableExistsQuery(table: String): String = {
    s"SELECT 1 FROM $table LIMIT 1"
  }

  override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
    super.beforeFetch(connection, properties)

    // According to the postgres jdbc documentation we need autocommit=false if we actually
    // want a non-zero fetchsize (0 means fetch all the rows). This avoids having to cache all
    // the rows inside the driver when fetching.
    //
    // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
    //
    if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
      connection.setAutoCommit(false)
    }

  }

  override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
} 
Example 14
Source File: DataSourceUtil.scala    From akka-tools   with MIT License
package no.nextgentel.oss.akkatools.persistence.jdbcjournal

import java.sql.Connection
import java.util.concurrent.atomic.AtomicInteger
import javax.sql.DataSource

import liquibase.{Contexts, Liquibase}
import liquibase.database.DatabaseFactory
import liquibase.database.jvm.JdbcConnection
import liquibase.resource.ClassLoaderResourceAccessor
import org.h2.jdbcx.JdbcDataSource

import scala.util.Random

object DataSourceUtil {

  def createDataSource(h2DbName:String, pathToLiquibaseFile:String = "akka-tools-jdbc-journal-liquibase.sql"):DataSource = {

    this.synchronized {
      val dataSource = new JdbcDataSource
      val name = s"$h2DbName-${Random.nextInt(1000)}"
      println(s"****> h2-name: '$name'")
      dataSource.setURL(s"jdbc:h2:mem:$name;mode=oracle;DB_CLOSE_DELAY=-1")
      dataSource.setUser("sa")
      dataSource.setPassword("sa")

      // We need to grab a connection and not release it to prevent the db from being
      // released when no connections are active..
      dataSource.getConnection


      updateDb(dataSource, pathToLiquibaseFile)

      dataSource
    }
  }


  private def createLiquibase(dbConnection: Connection, diffFilePath: String): Liquibase = {
    val database = DatabaseFactory.getInstance.findCorrectDatabaseImplementation(new JdbcConnection(dbConnection))
    val classLoader = DataSourceUtil.getClass.getClassLoader
    val resourceAccessor = new ClassLoaderResourceAccessor(classLoader)
    new Liquibase(diffFilePath, resourceAccessor, database)
  }

  private def updateDb(db: DataSource, diffFilePath: String): Unit = {
    val dbConnection = db.getConnection
    val liquibase = createLiquibase(dbConnection, diffFilePath)
    try {
      liquibase.update(null.asInstanceOf[Contexts])
    } catch {
      case e: Throwable => throw e
    } finally {
      liquibase.forceReleaseLocks()
      dbConnection.rollback()
      dbConnection.close()
    }
  }


} 
Example 15
Source File: DataSourceUtil.scala    From akka-tools   with MIT License
package no.nextgentel.oss.akkatools.utils

import java.sql.Connection
import javax.sql.DataSource

import liquibase.{Contexts, Liquibase}
import liquibase.database.DatabaseFactory
import liquibase.database.jvm.JdbcConnection
import liquibase.resource.ClassLoaderResourceAccessor
import org.h2.jdbcx.JdbcDataSource

import scala.util.Random

object DataSourceUtil {

  def createDataSource(h2DbName:String, pathToLiquibaseFile:String = "akka-tools-jdbc-journal-liquibase.sql"):DataSource = {

    this.synchronized {
      val dataSource = new JdbcDataSource
      val name = s"$h2DbName-${Random.nextInt(1000)}"
      println(s"****> h2-name: '$name'")
      dataSource.setURL(s"jdbc:h2:mem:$name;mode=oracle;DB_CLOSE_DELAY=-1")
      dataSource.setUser("sa")
      dataSource.setPassword("sa")

      // We need to grab a connection and not release it to prevent the db from being
      // released when no connections are active..
      dataSource.getConnection


      updateDb(dataSource, pathToLiquibaseFile)

      dataSource
    }
  }


  private def createLiquibase(dbConnection: Connection, diffFilePath: String): Liquibase = {
    val database = DatabaseFactory.getInstance.findCorrectDatabaseImplementation(new JdbcConnection(dbConnection))
    val classLoader = DataSourceUtil.getClass.getClassLoader
    val resourceAccessor = new ClassLoaderResourceAccessor(classLoader)
    new Liquibase(diffFilePath, resourceAccessor, database)
  }

  private def updateDb(db: DataSource, diffFilePath: String): Unit = {
    val dbConnection = db.getConnection
    val liquibase = createLiquibase(dbConnection, diffFilePath)
    try {
      liquibase.update(null.asInstanceOf[Contexts])
    } catch {
      case e: Throwable => throw e
    } finally {
      liquibase.forceReleaseLocks()
      dbConnection.rollback()
      dbConnection.close()
    }
  }


} 
Example 16
Source File: SlickJdbcMigration.scala    From reliable-http-client   with Apache License 2.0
package rhttpc.transport.amqpjdbc.slick

import java.io.PrintWriter
import java.lang.reflect.{InvocationHandler, Method, Proxy}
import java.sql.Connection
import java.util.logging.Logger

import javax.sql.DataSource
import org.flywaydb.core.api.migration.{BaseJavaMigration, Context}
import slick.jdbc.JdbcProfile

import scala.concurrent.Await
import scala.concurrent.duration._

trait SlickJdbcMigration extends BaseJavaMigration {

  protected val profile: JdbcProfile

  import profile.api._

  def migrateActions: DBIOAction[Any, NoStream, _ <: Effect]

  override final def migrate(context: Context): Unit = {
    val database = Database.forDataSource(new AlwaysUsingSameConnectionDataSource(context.getConnection), None)
    Await.result(database.run(migrateActions), 10 minute)
  }

}

class AlwaysUsingSameConnectionDataSource(conn: Connection) extends DataSource {
  private val notClosingConnection = Proxy.newProxyInstance(
    ClassLoader.getSystemClassLoader,
    Array[Class[_]](classOf[Connection]),
    SuppressCloseHandler
  ).asInstanceOf[Connection]

  object SuppressCloseHandler extends InvocationHandler {
    override def invoke(proxy: AnyRef, method: Method, args: Array[AnyRef]): AnyRef = {
      if (method.getName != "close") {
        method.invoke(conn, args : _*)
      } else {
        null
      }
    }
  }

  override def getConnection: Connection = notClosingConnection
  override def getConnection(username: String, password: String): Connection = notClosingConnection
  override def unwrap[T](iface: Class[T]): T = conn.unwrap(iface)
  override def isWrapperFor(iface: Class[_]): Boolean = conn.isWrapperFor(iface)

  override def setLogWriter(out: PrintWriter): Unit = ???
  override def getLoginTimeout: Int = ???
  override def setLoginTimeout(seconds: Int): Unit = ???
  override def getParentLogger: Logger = ???
  override def getLogWriter: PrintWriter = ???
} 
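A hypothetical concrete migration built on the trait above; the H2 profile, the class name, and the DDL are assumptions for illustration:

import rhttpc.transport.amqpjdbc.slick.SlickJdbcMigration
import slick.jdbc.H2Profile

class V1__CreateMessagesTable extends SlickJdbcMigration {
  override protected val profile = H2Profile

  import profile.api._

  // a single DDL action; Flyway invokes it through migrate(context) defined by the trait
  override def migrateActions =
    sqlu"CREATE TABLE messages (id BIGINT PRIMARY KEY, payload VARCHAR)"
}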
Example 17
Source File: DatabaseInitializer.scala    From reliable-http-client   with Apache License 2.0
package rhttpc.transport.amqpjdbc.slick.helpers

import java.io.PrintWriter
import java.sql.Connection
import java.util.logging.Logger
import javax.sql.DataSource

import com.typesafe.config.Config
import org.flywaydb.core.Flyway
import slick.jdbc.JdbcBackend

import scala.concurrent.ExecutionContext

class DatabaseInitializer(db: JdbcBackend.Database) {
  def initDatabase()(implicit executionContext: ExecutionContext) = {
    migrateIfNeeded(db)
    db
  }

  private def migrateIfNeeded(db: JdbcBackend.Database) = {
    Flyway.configure
      .dataSource(new DatabaseDataSource(db))
      .baselineOnMigrate(true)
      .load
      .migrate
  }
}

object DatabaseInitializer {
  def apply(config: Config) = {
    val db = JdbcBackend.Database.forConfig("db", config)
    new DatabaseInitializer(db)
  }
}

class DatabaseDataSource(db: JdbcBackend.Database) extends DataSource {
  private val conn = db.createSession().conn

  override def getConnection: Connection = conn
  override def getConnection(username: String, password: String): Connection = conn
  override def unwrap[T](iface: Class[T]): T = conn.unwrap(iface)
  override def isWrapperFor(iface: Class[_]): Boolean = conn.isWrapperFor(iface)

  override def setLogWriter(out: PrintWriter): Unit = ???
  override def getLoginTimeout: Int = ???
  override def setLoginTimeout(seconds: Int): Unit = ???
  override def getParentLogger: Logger = ???
  override def getLogWriter: PrintWriter = ???
} 
Example 18
Source File: BucketExprPartitionStrategy.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}

import io.eels.Row
import io.eels.component.jdbc.dialect.JdbcDialect
import io.eels.datastream.Publisher

case class BucketExprPartitionStrategy(bucketExpressions: Seq[String]) extends JdbcPartitionStrategy {

  override def parts(connFn: () => Connection,
                     query: String,
                     bindFn: (PreparedStatement) => Unit,
                     fetchSize: Int,
                     dialect: JdbcDialect): Seq[Publisher[Seq[Row]]] = {

    bucketExpressions.map { bucketExpression =>
      val partitionedQuery = s""" SELECT * FROM ( $query ) WHERE $bucketExpression """
      new JdbcPublisher(connFn, partitionedQuery, bindFn, fetchSize, dialect)
    }
  }
} 
Example 19
Source File: JdbcPrimitives.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.jdbc

import java.sql.{Connection, DriverManager, ResultSet}

import com.sksamuel.exts.Logging
import io.eels.component.jdbc.dialect.JdbcDialect
import io.eels.schema.StructType

trait JdbcPrimitives extends Logging {

  def connect(url: String): Connection = {
    logger.debug(s"Connecting to jdbc source $url...")
    val conn = DriverManager.getConnection(url)
    logger.debug(s"Connected to $url")
    conn
  }

  def schemaFor(dialect: JdbcDialect, rs: ResultSet): StructType = {
    val schema = JdbcSchemaFns.fromJdbcResultset(rs, dialect)
    logger.trace("Fetched schema:\n" + schema.show())
    schema
  }
} 
Example 20
Source File: JdbcSink.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.jdbc

import java.sql.{Connection, DriverManager}

import com.sksamuel.exts.Logging
import com.typesafe.config.ConfigFactory
import io.eels.Sink
import io.eels.component.jdbc.dialect.{GenericJdbcDialect, JdbcDialect}
import io.eels.schema.StructType
import com.sksamuel.exts.OptionImplicits._

object JdbcSink extends Logging {

  private val config = ConfigFactory.load()
  private val warnIfMissingRewriteBatchedStatements = config.getBoolean("eel.jdbc.sink.warnIfMissingRewriteBatchedStatements")

  def apply(url: String, table: String): JdbcSink = {
    if (!url.contains("rewriteBatchedStatements")) {
      if (warnIfMissingRewriteBatchedStatements) {
        logger.warn("JDBC connection string does not contain the property 'rewriteBatchedStatements=true' which can be a major performance boost when writing data via JDBC. " +
          "Add this property to your connection string, or to remove this warning set eel.jdbc.warnIfMissingRewriteBatchedStatements=false")
      }
    }
    JdbcSink(() => DriverManager.getConnection(url), table)
  }
}

case class JdbcSink(connFn: () => Connection,
                    table: String,
                    createTable: Boolean = false,
                    dropTable: Boolean = false,
                    batchSize: Int = 1000, // the number of rows before a commit is made
                    batchesPerCommit: Int = 0, // 0 means commit at the end, otherwise how many batches before a commit
                    dialect: Option[JdbcDialect] = None,
                    threads: Int = 4) extends Sink with Logging {

  private val config = ConfigFactory.load()
  private val bufferSize = config.getInt("eel.jdbc.sink.bufferSize")
  private val autoCommit = config.getBoolean("eel.jdbc.sink.autoCommit")

  def withCreateTable(createTable: Boolean): JdbcSink = copy(createTable = createTable)
  def withDropTable(dropTable: Boolean): JdbcSink = copy(dropTable = dropTable)
  def withBatchSize(batchSize: Int): JdbcSink = copy(batchSize = batchSize)
  def withThreads(threads: Int): JdbcSink = copy(threads = threads)
  def withBatchesPerCommit(commitSize: Int): JdbcSink = copy(batchesPerCommit = commitSize)
  def withDialect(dialect: JdbcDialect): JdbcSink = copy(dialect = dialect.some)

  override def open(schema: StructType) =
    new JdbcSinkWriter(schema, connFn, table, createTable, dropTable, dialect.getOrElse(new GenericJdbcDialect), threads, batchSize, batchesPerCommit, autoCommit, bufferSize)
} 
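A hedged usage sketch of the sink above; the JDBC URL and table name are placeholders:

import io.eels.component.jdbc.JdbcSink

object JdbcSinkUsageSketch {
  // build a sink that creates the target table if it is missing and batches 500 rows per insert
  val sink: JdbcSink = JdbcSink("jdbc:h2:mem:sales", "sales")
    .withCreateTable(true)
    .withBatchSize(500)
    .withThreads(2)
  // an eel data stream would then be written to this sink in the usual way
}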
Example 21
Source File: RangePartitionStrategy.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}

import io.eels.Row
import io.eels.component.jdbc.dialect.JdbcDialect
import io.eels.datastream.Publisher

case class RangePartitionStrategy(columnName: String,
                                  numberOfPartitions: Int,
                                  min: Long,
                                  max: Long) extends JdbcPartitionStrategy {

  def ranges: Seq[(Long, Long)] = {

    // distribute surplus as evenly as possible across buckets
    // min max + 1 because the min-max range is inclusive
    val surplus = (max - min + 1) % numberOfPartitions
    val gap = (max - min + 1) / numberOfPartitions

    List.tabulate(numberOfPartitions) { k =>
      val start = min + k * gap + Math.min(k, surplus)
      val end = min + ((k + 1) * gap) + Math.min(k + 1, surplus)
      (start, end - 1)
    }
  }

  override def parts(connFn: () => Connection,
                     query: String,
                     bindFn: (PreparedStatement) => Unit,
                     fetchSize: Int,
                     dialect: JdbcDialect): Seq[Publisher[Seq[Row]]] = {

    ranges.map { case (start, end) =>

      val partitionedQuery =
        s"""SELECT * FROM ( $query ) WHERE $start <= $columnName AND $columnName <= $end"""

      new JdbcPublisher(connFn, partitionedQuery, bindFn, fetchSize, dialect)
    }
  }
} 
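For illustration (the values are arbitrary), the strategy above splits an inclusive min-max range and spreads any surplus over the first buckets:

import io.eels.component.jdbc.RangePartitionStrategy

object RangePartitionSketch {
  // 10 values split across 3 partitions: the surplus of 1 widens the first bucket
  val strategy = RangePartitionStrategy("id", numberOfPartitions = 3, min = 1, max = 10)
  val ranges: Seq[(Long, Long)] = strategy.ranges // Seq((1,4), (5,7), (8,10))
}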
Example 22
Source File: HashPartitionStrategy.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}

import io.eels.Row
import io.eels.component.jdbc.dialect.JdbcDialect
import io.eels.datastream.Publisher

case class HashPartitionStrategy(hashExpression: String,
                                 numberOfPartitions: Int) extends JdbcPartitionStrategy {

  def partitionedQuery(partNum: Int, query: String): String =
    s"""SELECT * from ($query) WHERE $hashExpression = $partNum""".stripMargin

  override def parts(connFn: () => Connection,
                     query: String,
                     bindFn: (PreparedStatement) => Unit,
                     fetchSize: Int,
                     dialect: JdbcDialect): Seq[Publisher[Seq[Row]]] = {

    for (k <- 0 until numberOfPartitions) yield {
      new JdbcPublisher(connFn, partitionedQuery(k, query), bindFn, fetchSize, dialect)
    }
  }
} 
Example 23
Source File: JdbcInserter.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.jdbc

import java.sql.Connection

import com.sksamuel.exts.Logging
import com.sksamuel.exts.io.Using
import com.sksamuel.exts.jdbc.ResultSetIterator
import io.eels.Row
import io.eels.component.jdbc.dialect.JdbcDialect
import io.eels.schema.StructType

class JdbcInserter(val connFn: () => Connection,
                   val table: String,
                   val schema: StructType,
                   val autoCommit: Boolean,
                   val batchesPerCommit: Int,
                   val dialect: JdbcDialect) extends Logging with Using {

  logger.debug("Connecting to JDBC to insert.. ..")
  private val conn = connFn()
  conn.setAutoCommit(autoCommit)
  logger.debug(s"Connected successfully; autoCommit=$autoCommit")

  private var batches = 0

  def insertBatch(batch: Seq[Row]): Unit = {
    val stmt = conn.prepareStatement(dialect.insertQuery(schema, table))
    try {
      batch.foreach { row =>
        row.values.zipWithIndex.foreach { case (value, k) =>
          dialect.setObject(k, value, row.schema.field(k), stmt, conn)
        }
        stmt.addBatch()
      }
      batches = batches + 1
      stmt.executeBatch()
      if (!autoCommit) conn.commit()
      else if (batches == batchesPerCommit) {
        batches = 0
        conn.commit()
      }
    } catch {
      case t: Throwable =>
        logger.error("Batch failure", t)
        if (!autoCommit)
          conn.rollback()
        throw t
    } finally {
      stmt.close()
    }
  }

  def dropTable(): Unit = using(conn.createStatement)(_.execute(s"DROP TABLE IF EXISTS $table"))

  def tableExists(): Boolean = {
    logger.debug(s"Fetching list of tables to detect if $table exists")
    val tables = ResultSetIterator.strings(conn.getMetaData.getTables(null, null, null, Array("TABLE"))).toList
    val tableNames = tables.map(x => x(3).toLowerCase)
    val exists = tableNames.contains(table.toLowerCase())
    logger.debug(s"${tables.size} tables found; $table exists == $exists")
    exists
  }

  def ensureTableCreated(): Unit = {
    logger.info(s"Ensuring table [$table] is created")

    if (!tableExists()) {
      val sql = dialect.create(schema, table)
      logger.info(s"Creating table $table [$sql]")
      val stmt = conn.createStatement()
      try {
        stmt.executeUpdate(sql)
        if (!autoCommit) conn.commit()
      } catch {
        case t: Throwable =>
          logger.error("Batch failure", t)
          if (!autoCommit)
            conn.rollback()
          throw t
      } finally {
        stmt.close()
      }
    }
  }
} 
Example 24
Source File: JdbcTable.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.jdbc

import java.sql.{Connection, DatabaseMetaData, ResultSet}

import com.sksamuel.exts.Logging
import com.sksamuel.exts.io.Using
import com.sksamuel.exts.jdbc.ResultSetIterator
import io.eels.component.jdbc.dialect.{GenericJdbcDialect, JdbcDialect}
import io.eels.schema.{Field, StructType}

case class JdbcTable(tableName: String,
                     dialect: JdbcDialect = new GenericJdbcDialect,
                     catalog: Option[String] = None,
                     dbSchema: Option[String] = None)
                    (implicit conn: Connection) extends Logging with JdbcPrimitives with Using {

  private val dbPrefix: String = if (dbSchema.nonEmpty) dbSchema.get + "." else ""
  private val databaseMetaData: DatabaseMetaData = conn.getMetaData
  private val tables = RsIterator(databaseMetaData.getTables(catalog.orNull, dbSchema.orNull, null, Array("TABLE", "VIEW")))
    .map(_.getString("TABLE_NAME"))

  val candidateTableName: String = tables.find(_.toLowerCase == tableName.toLowerCase).getOrElse(sys.error(s"$tableName not found!"))
  val primaryKeys: Seq[String] = RsIterator(databaseMetaData.getPrimaryKeys(catalog.orNull, dbSchema.orNull, candidateTableName))
    .map(_.getString("COLUMN_NAME")).toSeq

  val schema = StructType(
    JdbcSchemaFns
      .fromJdbcResultset(conn.createStatement().executeQuery(s"SELECT * FROM $dbPrefix$candidateTableName WHERE 1=0"), dialect)
      .fields
      .map { f =>
        Field(name = f.name,
          dataType = f.dataType,
          nullable = f.nullable,
          key = primaryKeys.contains(f.name),
          metadata = f.metadata)
      }
  )

  private case class RsIterator(rs: ResultSet) extends Iterator[ResultSet] {
    def hasNext: Boolean = rs.next()

    def next(): ResultSet = rs
  }

} 
Example 25
Source File: JdbcSource.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.jdbc

import java.sql.{Connection, DriverManager, PreparedStatement}

import com.sksamuel.exts.Logging
import com.sksamuel.exts.io.Using
import com.sksamuel.exts.metrics.Timed
import io.eels.{Row, Source}
import io.eels.component.jdbc.dialect.{GenericJdbcDialect, JdbcDialect}
import io.eels.datastream.Publisher
import io.eels.schema.StructType

object JdbcSource {
  def apply(url: String, query: String): JdbcSource = JdbcSource(() => DriverManager.getConnection(url), query)
}

case class JdbcSource(connFn: () => Connection,
                      query: String,
                      bindFn: (PreparedStatement) => Unit = stmt => (),
                      fetchSize: Int = 200,
                      providedSchema: Option[StructType] = None,
                      providedDialect: Option[JdbcDialect] = None,
                      partitionStrategy: JdbcPartitionStrategy = SinglePartitionStrategy)
  extends Source with JdbcPrimitives with Logging with Using with Timed {

  override lazy val schema: StructType = providedSchema.getOrElse(fetchSchema())

  def withBind(bind: (PreparedStatement) => Unit): JdbcSource = copy(bindFn = bind)
  def withFetchSize(fetchSize: Int): JdbcSource = copy(fetchSize = fetchSize)
  def withProvidedSchema(schema: StructType): JdbcSource = copy(providedSchema = Option(schema))
  def withProvidedDialect(dialect: JdbcDialect): JdbcSource = copy(providedDialect = Option(dialect))
  def withPartitionStrategy(strategy: JdbcPartitionStrategy): JdbcSource = copy(partitionStrategy = strategy)

  private def dialect(): JdbcDialect = providedDialect.getOrElse(new GenericJdbcDialect())

  override def parts(): Seq[Publisher[Seq[Row]]] = partitionStrategy.parts(connFn, query, bindFn, fetchSize, dialect())

  def fetchSchema(): StructType = {
    using(connFn()) { conn =>
      val schemaQuery = s"SELECT * FROM ($query) tmp WHERE 1=0"
      using(conn.prepareStatement(schemaQuery)) { stmt =>

        stmt.setFetchSize(fetchSize)
        bindFn(stmt)

        val rs = timed(s"Executing query $query") {
          stmt.executeQuery()
        }

        val schema = schemaFor(dialect(), rs)
        rs.close()
        schema
      }
    }
  }
} 
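A hedged usage sketch of the source above; the URL, query, and hash expression are placeholders:

import io.eels.component.jdbc.{HashPartitionStrategy, JdbcSource}

object JdbcSourceUsageSketch {
  // read the query in four partitions by hashing a column expression
  val source: JdbcSource = JdbcSource("jdbc:h2:mem:sales", "SELECT * FROM sales")
    .withFetchSize(500)
    .withPartitionStrategy(HashPartitionStrategy("MOD(id, 4)", numberOfPartitions = 4))
  // source.parts() then yields one Publisher[Seq[Row]] per partition
}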
Example 26
Source File: BucketPartitionStrategy.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}

import io.eels.Row
import io.eels.component.jdbc.dialect.JdbcDialect
import io.eels.datastream.Publisher

case class BucketPartitionStrategy(columnName: String,
                                   values: Set[String]) extends JdbcPartitionStrategy {

  override def parts(connFn: () => Connection,
                     query: String,
                     bindFn: (PreparedStatement) => Unit,
                     fetchSize: Int,
                     dialect: JdbcDialect): Seq[Publisher[Seq[Row]]] = {

    values.map { value =>
      val partitionedQuery = s""" SELECT * FROM ( $query ) WHERE $columnName = '$value' """
      new JdbcPublisher(connFn, partitionedQuery, bindFn, fetchSize, dialect)
    }.toSeq
  }
} 
Example 27
Source File: JdbcSinkWriter.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.jdbc

import java.sql.Connection
import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

import com.sksamuel.exts.Logging
import io.eels.component.jdbc.dialect.JdbcDialect
import io.eels.schema.{Field, StructType}
import io.eels.{Row, SinkWriter}

class JdbcSinkWriter(schema: StructType,
                     connFn: () => Connection,
                     table: String,
                     createTable: Boolean,
                     dropTable: Boolean,
                     dialect: JdbcDialect,
                     threads: Int,
                     batchSize: Int,
                     batchesPerCommit: Int,
                     autoCommit: Boolean,
                     bufferSize: Int) extends SinkWriter with Logging {
  logger.info(s"Creating Jdbc writer with $threads threads, batch size $batchSize, autoCommit=$autoCommit")
  require(bufferSize >= batchSize)

  private val Sentinel = Row(StructType(Field("____jdbcsentinel")), Seq(null))

  import com.sksamuel.exts.concurrent.ExecutorImplicits._

  // the buffer is a concurrent receiver for the write method. It needs to hold enough elements so that
  // the invokers of this class can keep pumping in rows while we wait for a buffer to fill up.
  // the buffer size must be >= batch size or we'll never fill up enough to trigger a batch
  private val buffer = new LinkedBlockingQueue[Row](bufferSize)

  // the coordinator pool is just a single thread that runs the coordinator
  private val coordinatorPool = Executors.newSingleThreadExecutor()

  private lazy val inserter = {
    val inserter = new JdbcInserter(connFn, table, schema, autoCommit, batchesPerCommit, dialect)
    if (dropTable) {
      inserter.dropTable()
    }
    if (createTable) {
      inserter.ensureTableCreated()
    }
    inserter
  }

  // todo this needs to allow multiple batches at once
  coordinatorPool.submit {
    try {
      logger.debug("Starting JdbcWriter Coordinator")
      // once we receive the sentinel (the poison pill) it's all over for the writer
      Iterator.continually(buffer.take)
        .takeWhile(_ != Sentinel)
        .grouped(batchSize).withPartial(true)
        .foreach { batch =>
          inserter.insertBatch(batch)
        }
      logger.debug("Write completed; shutting down coordinator")
    } catch {
      case t: Throwable =>
        logger.error("Some error in coordinator", t)
    }
  }
  // the coordinator only runs a single task: read from the buffer
  // and perform the inserts
  coordinatorPool.shutdown()

  override def close(): Unit = {
    buffer.put(Sentinel)
    logger.info("Closing JDBC Writer... waiting on writes to finish")
    coordinatorPool.awaitTermination(1, TimeUnit.DAYS)
  }

  // when we get a row to write, we won't commit it immediately to the database,
  // but we'll buffer it so we can do batched inserts
  override def write(row: Row): Unit = {
    buffer.put(row)
  }
} 
Example 28
Source File: JdbcPublisher.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}
import java.util.concurrent.atomic.AtomicBoolean

import com.sksamuel.exts.io.Using
import com.sksamuel.exts.metrics.Timed
import io.eels.Row
import io.eels.component.jdbc.dialect.JdbcDialect
import io.eels.datastream.{Publisher, Subscriber, Subscription}

import scala.collection.mutable.ArrayBuffer

class JdbcPublisher(connFn: () => Connection,
                    query: String,
                    bindFn: (PreparedStatement) => Unit,
                    fetchSize: Int,
                    dialect: JdbcDialect
              ) extends Publisher[Seq[Row]] with Timed with JdbcPrimitives with Using {

  override def subscribe(subscriber: Subscriber[Seq[Row]]): Unit = {
    try {
      using(connFn()) { conn =>

        logger.debug(s"Preparing query $query")
        using(conn.prepareStatement(query)) { stmt =>

          stmt.setFetchSize(fetchSize)
          bindFn(stmt)

          logger.debug(s"Executing query $query")
          using(stmt.executeQuery()) { rs =>

            val schema = schemaFor(dialect, rs)

            val running = new AtomicBoolean(true)
            subscriber.subscribed(Subscription.fromRunning(running))

            val buffer = new ArrayBuffer[Row](fetchSize)
            while (rs.next && running.get) {
              val values = schema.fieldNames().map { name =>
                val raw = rs.getObject(name)
                dialect.sanitize(raw)
              }
              buffer append Row(schema, values)
              if (buffer.size == fetchSize) {
                subscriber.next(buffer.toVector)
                buffer.clear()
              }
            }

            if (buffer.nonEmpty)
              subscriber.next(buffer.toVector)

            subscriber.completed()
          }
        }
      }
    } catch {
      case t: Throwable => subscriber.error(t)
    }
  }
} 
Example 29
Source File: TiDBUtils.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark

import java.sql.{Connection, Driver, DriverManager}
import java.util.Properties

import com.pingcap.tispark.write.TiDBOptions
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, DriverWrapper}

import scala.util.Try

object TiDBUtils {
  private val TIDB_DRIVER_CLASS = "com.mysql.jdbc.Driver"

  
  def createConnectionFactory(jdbcURL: String): () => Connection = {
    import scala.collection.JavaConverters._
    val driverClass: String = TIDB_DRIVER_CLASS
    () => {
      DriverRegistry.register(driverClass)
      val driver: Driver = DriverManager.getDrivers.asScala
        .collectFirst {
          case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
          case d if d.getClass.getCanonicalName == driverClass => d
        }
        .getOrElse {
          throw new IllegalStateException(
            s"Did not find registered driver with class $driverClass")
        }
      driver.connect(jdbcURL, new Properties())
    }
  }
} 
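A usage sketch for the connection factory above. The TiDB address and credentials are placeholders:

import java.sql.Connection
import com.pingcap.tispark.TiDBUtils

object TiDBUtilsUsage extends App {
  val connFn: () => Connection =
    TiDBUtils.createConnectionFactory("jdbc:mysql://127.0.0.1:4000/test?user=root&password=")

  val conn = connFn()
  try {
    // Each invocation of connFn registers the MySQL driver (if needed) and opens a fresh connection.
    println(conn.getMetaData.getDatabaseProductVersion)
  } finally {
    conn.close()
  }
}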
Example 30
Source File: PredefinedTag.scala    From smui   with Apache License 2.0 5 votes vote down vote up
package models

import java.io.InputStream
import java.sql.Connection

import play.api.Logger
import play.api.libs.json.{Json, OFormat}

case class PredefinedTag(property: Option[String],
                         value: String,
                         solrIndexName: Option[String],
                         exported: Option[Boolean]) {

}

object PredefinedTag {

  val logger = Logger(getClass)

  implicit val jsonFormat: OFormat[PredefinedTag] = Json.format[PredefinedTag]

  def fromStream(stream: InputStream): Seq[PredefinedTag] = {
    try {
      Json.parse(stream).as[Seq[PredefinedTag]]
    } finally {
      stream.close()
    }
  }

  def updateInDB(predefinedTags: Seq[PredefinedTag])(implicit connection: Connection): (Seq[InputTagId], Seq[InputTag]) = {
    val indexIdsByName = SolrIndex.listAll.map(i => i.name -> i.id).toMap
    val tagsInDBByContent = InputTag.loadAll().map(t => t.tagContent -> t).toMap

    val newTags = predefinedTags.map { tag =>
      TagContent(tag.solrIndexName.flatMap(indexIdsByName.get), tag.property, tag.value) -> tag
    }.toMap

    val toDelete = tagsInDBByContent.filter { case (content, tag) => tag.predefined && !newTags.contains(content) }.map(_._2.id).toSeq
    val toInsert = newTags.filter(t => !tagsInDBByContent.contains(t._1)).map { case (tc, t) =>
      InputTag.create(tc.solrIndexId, t.property, t.value, t.exported.getOrElse(true), predefined = true)
    }.toSeq

    InputTag.insert(toInsert: _*)
    InputTag.deleteByIds(toDelete)
    if (toDelete.nonEmpty || toInsert.nonEmpty) {
      logger.info(s"Inserted ${toInsert.size} new predefined tags into the DB and deleted ${toDelete.size} no longer existing predefined tags.")
    }

    (toDelete, toInsert)
  }

} 
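A sketch of loading predefined tags from a JSON classpath resource; the resource name is illustrative. updateInDB can then be called with an implicit java.sql.Connection in scope:

import models.PredefinedTag

object PredefinedTagUsage {
  // "/predefined-tags.json" is a hypothetical classpath resource containing a JSON array of tags.
  def load(): Seq[PredefinedTag] = {
    val stream = getClass.getResourceAsStream("/predefined-tags.json")
    PredefinedTag.fromStream(stream) // fromStream closes the stream itself
  }
}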
Example 31
Source File: SearchInputWithRules.scala    From smui   with Apache License 2.0 5 votes vote down vote up
package models

import java.sql.Connection
import models.rules._
import play.api.libs.json.{Json, OFormat}

case class SearchInputWithRules(id: SearchInputId,
                                term: String,
                                synonymRules: List[SynonymRule] = Nil,
                                upDownRules: List[UpDownRule] = Nil,
                                filterRules: List[FilterRule] = Nil,
                                deleteRules: List[DeleteRule] = Nil,
                                redirectRules: List[RedirectRule] = Nil,
                                tags: Seq[InputTag] = Seq.empty,
                                isActive: Boolean,
                                comment: String) {

  lazy val trimmedTerm: String = term.trim()

  def allRules: List[Rule] = {
    synonymRules ++ upDownRules ++ filterRules ++ deleteRules ++ redirectRules
  }

  def hasAnyActiveRules: Boolean = {
    allRules.exists(r => r.isActive)
  }

}

object SearchInputWithRules {

  implicit val jsonFormat: OFormat[SearchInputWithRules] = Json.format[SearchInputWithRules]

  def loadById(id: SearchInputId)(implicit connection: Connection): Option[SearchInputWithRules] = {
    SearchInput.loadById(id).map { input =>
      SearchInputWithRules(input.id, input.term,
        synonymRules = SynonymRule.loadByInputId(id),
        upDownRules = UpDownRule.loadByInputId(id),
        filterRules = FilterRule.loadByInputId(id),
        deleteRules = DeleteRule.loadByInputId(id),
        redirectRules = RedirectRule.loadByInputId(id),
        tags = TagInputAssociation.loadTagsBySearchInputId(id),
        isActive = input.isActive,
        comment = input.comment)
    }
  }

  
  def loadWithUndirectedSynonymsAndTagsForSolrIndexId(solrIndexId: SolrIndexId)(implicit connection: Connection): List[SearchInputWithRules] = {
    val inputs = SearchInput.loadAllForIndex(solrIndexId)
    val rules = SynonymRule.loadUndirectedBySearchInputIds(inputs.map(_.id))
    val tags = TagInputAssociation.loadTagsBySearchInputIds(inputs.map(_.id))

    inputs.map { input =>
      SearchInputWithRules(input.id, input.term,
        synonymRules = rules.getOrElse(input.id, Nil).toList,
        tags = tags.getOrElse(input.id, Seq.empty),
        isActive = input.isActive,
        comment = input.comment) // TODO consider only transferring "hasComment" for list overview
    }
  }

  def update(searchInput: SearchInputWithRules)(implicit connection: Connection): Unit = {
    SearchInput.update(searchInput.id, searchInput.term, searchInput.isActive, searchInput.comment)

    SynonymRule.updateForSearchInput(searchInput.id, searchInput.synonymRules)
    UpDownRule.updateForSearchInput(searchInput.id, searchInput.upDownRules)
    FilterRule.updateForSearchInput(searchInput.id, searchInput.filterRules)
    DeleteRule.updateForSearchInput(searchInput.id, searchInput.deleteRules)
    RedirectRule.updateForSearchInput(searchInput.id, searchInput.redirectRules)

    TagInputAssociation.updateTagsForSearchInput(searchInput.id, searchInput.tags.map(_.id))
  }

  def delete(id: SearchInputId)(implicit connection: Connection): Int = {
    val deleted = SearchInput.delete(id)
    if (deleted > 0) {
      for (rule <- Rule.allRules) {
        rule.deleteBySearchInput(id)
      }
      TagInputAssociation.deleteBySearchInputId(id)
    }
    deleted
  }

} 
Example 32
Source File: SuggestedSolrField.scala    From smui   with Apache License 2.0 5 votes vote down vote up
package models

import java.sql.Connection
import java.time.LocalDateTime

import anorm.SqlParser.get
import anorm._
import play.api.libs.json.{Json, OFormat}

class SuggestedSolrFieldId(id: String) extends Id(id)
object SuggestedSolrFieldId extends IdObject[SuggestedSolrFieldId](new SuggestedSolrFieldId(_))


case class SuggestedSolrField(id: SuggestedSolrFieldId = SuggestedSolrFieldId(),
                         name: String) {

}

object SuggestedSolrField {

  implicit val jsonFormat: OFormat[SuggestedSolrField] = Json.format[SuggestedSolrField]

  val TABLE_NAME = "suggested_solr_field"
  val ID = "id"
  val NAME = "name"
  val SOLR_INDEX_ID = "solr_index_id"
  val LAST_UPDATE = "last_update"

  val sqlParser: RowParser[SuggestedSolrField] = {
    get[SuggestedSolrFieldId](s"$TABLE_NAME.$ID") ~
      get[String](s"$TABLE_NAME.$NAME") map { case id ~ name =>
      SuggestedSolrField(id, name)
    }
  }

  def listAll(solrIndexId: SolrIndexId)(implicit connection: Connection): List[SuggestedSolrField] = {
    SQL"select * from #$TABLE_NAME where #$SOLR_INDEX_ID = $solrIndexId order by #$NAME asc".as(sqlParser.*)
  }

  def insert(solrIndexId: SolrIndexId, fieldName: String)(implicit connection: Connection): SuggestedSolrField = {
    val field = SuggestedSolrField(SuggestedSolrFieldId(), fieldName)
    SQL(s"insert into $TABLE_NAME($ID, $NAME, $SOLR_INDEX_ID, $LAST_UPDATE) values ({$ID}, {$NAME}, {$SOLR_INDEX_ID}, {$LAST_UPDATE})")
      .on(
        ID -> field.id,
        NAME -> fieldName,
        SOLR_INDEX_ID -> solrIndexId,
        LAST_UPDATE -> LocalDateTime.now()
      )
      .execute()
    field
  }


} 
Example 33
Source File: SynonymRule.scala    From smui   with Apache License 2.0 5 votes vote down vote up
package models.rules

import java.sql.Connection

import anorm.SqlParser.get
import anorm._
import models.{Id, IdObject, SearchInputId}
import play.api.libs.json.{Json, OFormat}

class SynonymRuleId(id: String) extends Id(id)
object SynonymRuleId extends IdObject[SynonymRuleId](new SynonymRuleId(_))

case class SynonymRule(id: SynonymRuleId = SynonymRuleId(),
                       synonymType: Int,
                       term: String,
                       isActive: Boolean) extends RuleWithTerm {

  override def toNamedParameters(searchInputId: SearchInputId): Seq[NamedParameter] = {
    super.toNamedParameters(searchInputId) ++ Seq[NamedParameter](
      SynonymRule.TYPE -> synonymType
    )
  }
}

object SynonymRule extends RuleObjectWithTerm[SynonymRule] {

  val TABLE_NAME = "synonym_rule"
  val TYPE = "synonym_type"

  val TYPE_UNDIRECTED = 0
  val TYPE_DIRECTED = 1

  implicit val jsonFormat: OFormat[SynonymRule] = Json.format[SynonymRule]

  override def fieldNames: Seq[String] = super.fieldNames :+ TYPE

  val sqlParser: RowParser[SynonymRule] = {
    get[SynonymRuleId](s"$TABLE_NAME.$ID") ~
      get[Int](s"$TABLE_NAME.$TYPE") ~
      get[String](s"$TABLE_NAME.$TERM") ~
      get[Int](s"$TABLE_NAME.$STATUS") map { case id ~ synonymType ~ term ~ status =>
        SynonymRule(id, synonymType, term, isActiveFromStatus(status))
    }
  }

  def loadUndirectedBySearchInputIds(ids: Seq[SearchInputId])(implicit connection: Connection): Map[SearchInputId, Seq[SynonymRule]] = {
    ids.grouped(100).toSeq.flatMap { idGroup =>
      SQL"select * from #$TABLE_NAME where #$TYPE = #$TYPE_UNDIRECTED AND #$SEARCH_INPUT_ID in ($idGroup)".as((sqlParser ~ get[SearchInputId](SEARCH_INPUT_ID)).*).map { case rule ~ id =>
        id -> rule
      }
    }.groupBy(_._1).mapValues(_.map(_._2))
  }

} 
Example 34
Source File: SearchInput.scala    From smui   with Apache License 2.0 5 votes vote down vote up
package models

import java.sql.Connection
import java.time.LocalDateTime

import anorm.SqlParser.get
import anorm._

class SearchInputId(id: String) extends Id(id)
object SearchInputId extends IdObject[SearchInputId](new SearchInputId(_))

case class SearchInput(id: SearchInputId = SearchInputId(),
                       solrIndexId: SolrIndexId,
                       term: String,
                       lastUpdate: LocalDateTime,
                       isActive: Boolean,
                       comment: String) {

  import SearchInput._

  def status: Int = statusFromIsActive(isActive)

  def toNamedParameters: Seq[NamedParameter] = Seq(
    ID -> id,
    SOLR_INDEX_ID -> solrIndexId,
    TERM -> term,
    LAST_UPDATE -> lastUpdate,
    STATUS -> status,
    COMMENT -> comment
  )

}

object SearchInput {

  val TABLE_NAME = "search_input"
  val ID = "id"
  val TERM = "term"
  val SOLR_INDEX_ID = "solr_index_id"
  val LAST_UPDATE = "last_update"
  val STATUS = "status"
  val COMMENT = "comment"

  def isActiveFromStatus(status: Int): Boolean = {
    (status & 0x01) == 0x01
  }

  def statusFromIsActive(isActive: Boolean) = {
    if (isActive) 0x01 else 0x00
  }

  val sqlParser: RowParser[SearchInput] = {
    get[SearchInputId](s"$TABLE_NAME.$ID") ~
      get[String](s"$TABLE_NAME.$TERM") ~
      get[SolrIndexId](s"$TABLE_NAME.$SOLR_INDEX_ID") ~
      get[LocalDateTime](s"$TABLE_NAME.$LAST_UPDATE") ~
      get[Int](s"$TABLE_NAME.$STATUS") ~
      get[String](s"$TABLE_NAME.$COMMENT") map { case id ~ term ~ indexId ~ lastUpdate ~ status ~ comment =>
        SearchInput(id, indexId, term, lastUpdate, isActiveFromStatus(status), comment)
    }
  }

  def insert(solrIndexId: SolrIndexId, term: String)(implicit connection: Connection): SearchInput = {
    val input = SearchInput(SearchInputId(), solrIndexId, term, LocalDateTime.now(), true, "")
    SQL(s"insert into $TABLE_NAME ($ID, $TERM, $SOLR_INDEX_ID, $LAST_UPDATE, $STATUS, $COMMENT) values ({$ID}, {$TERM}, {$SOLR_INDEX_ID}, {$LAST_UPDATE}, {$STATUS}, {$COMMENT})")
      .on(input.toNamedParameters: _*).execute()
    input
  }

  def loadAllForIndex(solrIndexId: SolrIndexId)(implicit connection: Connection): List[SearchInput] = {
    SQL"select * from #$TABLE_NAME where #$SOLR_INDEX_ID = $solrIndexId order by #$TERM asc".as(sqlParser.*)
  }

  def loadAllIdsForIndex(solrIndexId: SolrIndexId)(implicit connection: Connection): List[SearchInputId] = {
    SQL"select #$ID from #$TABLE_NAME where #$SOLR_INDEX_ID = $solrIndexId order by #$TERM asc".as(get[SearchInputId](ID).*)
  }

  def loadById(id: SearchInputId)(implicit connection: Connection): Option[SearchInput] = {
    SQL"select * from #$TABLE_NAME where #$ID = $id".as(sqlParser.*).headOption
  }

  def update(id: SearchInputId, term: String, isActive: Boolean, comment: String)(implicit connection: Connection): Unit = {
    SQL"update #$TABLE_NAME set #$TERM = $term, #$LAST_UPDATE = ${LocalDateTime.now()}, #$STATUS = ${statusFromIsActive(isActive)}, #$COMMENT = $comment where #$ID = $id".executeUpdate()
  }

  
  def delete(id: SearchInputId)(implicit connection: Connection): Int = {
    SQL"delete from #$TABLE_NAME where #$ID = $id".executeUpdate()
  }

} 
Example 35
Source File: InputTag.scala    From smui   with Apache License 2.0 5 votes vote down vote up
package models

import java.sql.Connection
import java.time.LocalDateTime

import anorm._
import anorm.SqlParser.get
import play.api.libs.json._

class InputTagId(id: String) extends Id(id)
object InputTagId extends IdObject[InputTagId](new InputTagId(_))



case class InputTag(id: InputTagId,
                    solrIndexId: Option[SolrIndexId],
                    property: Option[String],
                    value: String,
                    exported: Boolean,
                    predefined: Boolean,
                    lastUpdate: LocalDateTime) {

  import InputTag._

  def toNamedParameters: Seq[NamedParameter] = Seq(
    ID -> id,
    SOLR_INDEX_ID -> solrIndexId,
    PROPERTY -> property,
    VALUE -> value,
    EXPORTED -> (if (exported) 1 else 0),
    PREDEFINED -> (if (predefined) 1 else 0),
    LAST_UPDATE -> lastUpdate
  )

  def tagContent = TagContent(solrIndexId, property, value)

  def displayValue: String = property.map(p => s"$p:").getOrElse("") + value

}

object InputTag {

  val TABLE_NAME = "input_tag"

  val ID = "id"
  val SOLR_INDEX_ID = "solr_index_id"
  val PROPERTY = "property"
  val VALUE = "tag_value"
  val EXPORTED = "exported"
  val PREDEFINED = "predefined"
  val LAST_UPDATE = "last_update"

  implicit val jsonReads: Reads[InputTag] = Json.reads[InputTag]

  private val defaultWrites: OWrites[InputTag] = Json.writes[InputTag]
  implicit val jsonWrites: OWrites[InputTag] = OWrites[InputTag] { tag =>
    Json.obj("displayValue" -> tag.displayValue) ++ defaultWrites.writes(tag)
  }

  def create(solrIndexId: Option[SolrIndexId],
             property: Option[String],
             value: String,
             exported: Boolean,
             predefined: Boolean = false): InputTag = {
    InputTag(InputTagId(), solrIndexId, property, value, exported, predefined, LocalDateTime.now())
  }

  val sqlParser: RowParser[InputTag] = get[InputTagId](s"$TABLE_NAME.$ID") ~
    get[Option[SolrIndexId]](s"$TABLE_NAME.$SOLR_INDEX_ID") ~
    get[Option[String]](s"$TABLE_NAME.$PROPERTY") ~
    get[String](s"$TABLE_NAME.$VALUE") ~
    get[Int](s"$TABLE_NAME.$EXPORTED") ~
    get[Int](s"$TABLE_NAME.$PREDEFINED") ~
    get[LocalDateTime](s"$TABLE_NAME.$LAST_UPDATE") map { case id ~ solrIndexId ~ property ~ value ~ exported ~ predefined ~ lastUpdate =>
      InputTag(id, solrIndexId, property,
        value, exported > 0, predefined > 0, lastUpdate)
  }

  def insert(tags: InputTag*)(implicit connection: Connection): Unit = {
    if (tags.nonEmpty) {
      BatchSql(s"insert into $TABLE_NAME ($ID, $SOLR_INDEX_ID, $PROPERTY, $VALUE, $EXPORTED, $PREDEFINED, $LAST_UPDATE) " +
        s"values ({$ID}, {$SOLR_INDEX_ID}, {$PROPERTY}, {$VALUE}, {$EXPORTED}, {$PREDEFINED}, {$LAST_UPDATE})",
        tags.head.toNamedParameters,
        tags.tail.map(_.toNamedParameters): _*
      ).execute()
    }
  }

  def loadAll()(implicit connection: Connection): Seq[InputTag] = {
    SQL(s"select * from $TABLE_NAME order by $PROPERTY asc, $VALUE asc")
      .as(sqlParser.*)
  }

  def deleteByIds(ids: Seq[InputTagId])(implicit connection: Connection): Unit = {
    for (idGroup <- ids.grouped(100)) {
      SQL"delete from #$TABLE_NAME where #$ID in ($idGroup)".executeUpdate()
    }
  }


} 
Example 36
Source File: SearchInputWithRulesSpec.scala    From smui   with Apache License 2.0 5 votes vote down vote up
package models

import java.sql.Connection

import models.rules.{SynonymRule, SynonymRuleId}
import org.scalatest.{FlatSpec, Matchers}
import utils.WithInMemoryDB

class SearchInputWithRulesSpec extends FlatSpec with Matchers with WithInMemoryDB with TestData {

  private val tag = InputTag.create(None, Some("tenant"), "MO", exported = true)

  "SearchInputWithRules" should "load lists with hundreds of entries successfully" in {
    db.withConnection { implicit conn =>
      SolrIndex.insert(indexDe)
      SolrIndex.insert(indexEn)
      InputTag.insert(tag)

      insertInputs(300, indexDe.id, "term_de")
      insertInputs(200, indexEn.id, "term_en")

      val inputsDe = SearchInputWithRules.loadWithUndirectedSynonymsAndTagsForSolrIndexId(indexDe.id)
      inputsDe.size shouldBe 300
      for (input <- inputsDe) {
        input.term should startWith("term_de_")
        input.tags.size shouldBe 1
        input.tags.head.displayValue shouldBe "tenant:MO"
        input.synonymRules.size shouldBe 1 // Only undirected synonyms should be loaded
        input.synonymRules.head.term should startWith("term_de_synonym_")
      }

      SearchInputWithRules.loadWithUndirectedSynonymsAndTagsForSolrIndexId(indexEn.id).size shouldBe 200
    }
  }

  private def insertInputs(count: Int, indexId: SolrIndexId, termPrefix: String)(implicit conn: Connection): Unit = {
    for (i <- 0 until count) {
      val input = SearchInput.insert(indexId, s"${termPrefix}_$i")
      SynonymRule.updateForSearchInput(input.id, Seq(
        SynonymRule(SynonymRuleId(), SynonymRule.TYPE_UNDIRECTED, s"${termPrefix}_synonym_$i", isActive = true),
        SynonymRule(SynonymRuleId(), SynonymRule.TYPE_DIRECTED, s"${termPrefix}_directedsyn_$i", isActive = true),
      ))
      TagInputAssociation.updateTagsForSearchInput(input.id, Seq(tag.id))
    }
  }

  "SearchInputWithRules" should "be (de)activatable" in {
    db.withConnection { implicit conn =>
      SolrIndex.insert(indexDe)

      val input = SearchInput.insert(indexDe.id, "my input")
      input.isActive shouldBe true

      SearchInput.update(input.id, input.term, false, input.comment)
      SearchInput.loadById(input.id).get.isActive shouldBe false

      SearchInput.update(input.id, input.term, true, input.comment)
      SearchInput.loadById(input.id).get.isActive shouldBe true
    }
  }

  "SearchInputWithRules" should "have a modifiable comment" in {
    db.withConnection { implicit conn =>
      SolrIndex.insert(indexDe)

      val input = SearchInput.insert(indexDe.id, "my input")
      input.comment shouldBe ""

      SearchInput.update(input.id, input.term, input.isActive, "My #magic comment.")
      SearchInput.loadById(input.id).get.comment shouldBe "My #magic comment."

      SearchInput.update(input.id, input.term, input.isActive, "My #magic comment - updated.")
      SearchInput.loadById(input.id).get.comment shouldBe "My #magic comment - updated."
    }
  }

} 
Example 37
Source File: PostgresConnection.scala    From darwin   with Apache License 2.0 5 votes vote down vote up
package it.agilelab.darwin.connector.postgres

import java.sql.{Connection, DriverManager}

import com.typesafe.config.Config

trait PostgresConnection {

  private var connectionUrl : String = ""
  private val driverName : String = "org.postgresql.Driver"

  protected def setConnectionConfig(config : Config) = {
    val db = config.getString(ConfigurationKeys.DATABASE)
    val host = config.getString(ConfigurationKeys.HOST)
    val user = config.getString(ConfigurationKeys.USER)
    val password = config.getString(ConfigurationKeys.PASSWORD)
    connectionUrl = s"jdbc:postgresql://$host/$db?user=$user&password=$password"
  }

  protected def getConnection: Connection = {
    Class.forName(driverName)
    val connection: Connection = DriverManager.getConnection(connectionUrl)
    connection
  }
} 
Example 38
Source File: PostgresDialect.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Types}

import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._


private object PostgresDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.REAL) {
      Some(FloatType)
    } else if (sqlType == Types.SMALLINT) {
      Some(ShortType)
    } else if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
      Some(BinaryType)
    } else if (sqlType == Types.OTHER) {
      Some(StringType)
    } else if (sqlType == Types.ARRAY) {
      val scale = md.build.getLong("scale").toInt
      // postgres array type names start with underscore
      toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_))
    } else None
  }

  private def toCatalystType(
      typeName: String,
      precision: Int,
      scale: Int): Option[DataType] = typeName match {
    case "bool" => Some(BooleanType)
    case "bit" => Some(BinaryType)
    case "int2" => Some(ShortType)
    case "int4" => Some(IntegerType)
    case "int8" | "oid" => Some(LongType)
    case "float4" => Some(FloatType)
    case "money" | "float8" => Some(DoubleType)
    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
      Some(StringType)
    case "bytea" => Some(BinaryType)
    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
    case "date" => Some(DateType)
    case "numeric" | "decimal" if precision > 0 => Some(DecimalType.bounded(precision, scale))
    case "numeric" | "decimal" =>
      // SPARK-26538: handle numeric without explicit precision and scale.
      Some(DecimalType.SYSTEM_DEFAULT)
    case _ => None
  }

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.CHAR))
    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
    case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
    case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case t: DecimalType => Some(
      JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
      getJDBCType(et).map(_.databaseTypeDefinition)
        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
    case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt");
    case _ => None
  }

  override def getTableExistsQuery(table: String): String = {
    s"SELECT 1 FROM $table LIMIT 1"
  }

  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)

  
  override def getTruncateQuery(
      table: String,
      cascade: Option[Boolean] = isCascadingTruncateTable): String = {
    cascade match {
      case Some(true) => s"TRUNCATE TABLE ONLY $table CASCADE"
      case _ => s"TRUNCATE TABLE ONLY $table"
    }
  }

  override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
    super.beforeFetch(connection, properties)

    // According to the postgres jdbc documentation we need to be in autocommit=false if we actually
    // want to have fetchsize be non 0 (all the rows).  This allows us to not have to cache all the
    // rows inside the driver when fetching.
    //
    // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
    //
    if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
      connection.setAutoCommit(false)
    }
  }

} 
Example 39
Source File: Alarm.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.alarm

import java.sql.Connection

import org.apache.spark.monitor.MonitorItem.MonitorItem

trait Alarm {
  val name: String
  var options: Map[String, String] = _
  def bind(options: Map[String, String]): Alarm = {
    this.options = options
    this
  }

  // Abstract alert hooks; the signatures are inferred from the JdbcAlarm implementation in Example 40.
  def alarm(msg: AlertMessage): AlertResp
  def finalAlarm(msg: AlertMessage): AlertResp
}

class AlertMessage(val title: MonitorItem) {
  def toCsv(): String = {
    throw new Exception("can not treat as csv")
  }
  def toHtml(): String = {
    ""
  }
  def toJdbc(conn: Connection, appId: String = ""): Unit = {
    // do nothing
  }
}

class HtmlMessage(title: MonitorItem, content: String) extends AlertMessage(title) {
  override def toHtml(): String = {
    content
  }
}

case class AlertResp(status: Boolean, ret: String)

object AlertResp {
  def success(ret: String): AlertResp = apply(status = true, ret)
  def failure(ret: String): AlertResp = apply(status = false, ret)
}
object AlertType extends Enumeration {
  type AlertType = Value
  val Application, Job, Stage, Task, Executor, SQL = Value
}
object JobType extends Enumeration {
  type JobType = Value
  val CORE, SQL, STREAMING, MLLIB, GRAPHX = Value
} 
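Assuming the trait shape reconstructed above (abstract alarm/finalAlarm hooks), a trivial Alarm implementation might look like this sketch:

import org.apache.spark.alarm.{Alarm, AlertMessage, AlertResp}

class ConsoleAlarm extends Alarm {
  override val name: String = "console"

  // Print the rendered message; the status is always reported as success.
  override def alarm(msg: AlertMessage): AlertResp = {
    println(msg.toHtml())
    AlertResp.success("")
  }

  override def finalAlarm(msg: AlertMessage): AlertResp = alarm(msg)
}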
Example 40
Source File: JdbcAlarm.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.alarm

import java.sql.{Connection, DriverManager}

class JdbcAlarm extends Alarm {
  override val name: String = "mysql"

  private val conn: Connection = getConnect

  private def getConnect(): Connection = {
    org.apache.spark.util.Utils.classForName("com.mysql.jdbc.Driver")
    DriverManager.getConnection(
      "jdbc:mysql://localhost:3306/xsql_monitor?useSSL=true",
      "xsql_monitor",
      "xsql_monitor")
  }

  
  override def alarm(msg: AlertMessage): AlertResp = {
    msg.toJdbc(conn)
    AlertResp.success("")
  }

  override def finalAlarm(msg: AlertMessage): AlertResp = {
    msg.toJdbc(conn)
    AlertResp.success("")
  }
} 
Example 41
Source File: ApplicationMonitor.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.monitor.application

import java.sql.{Connection, Timestamp}
import java.text.SimpleDateFormat
import java.util.Date
import java.util.concurrent.TimeUnit

import scala.concurrent.duration.Duration

import org.apache.spark.alarm.AlertMessage
import org.apache.spark.alarm.AlertType._
import org.apache.spark.monitor.Monitor
import org.apache.spark.monitor.MonitorItem.MonitorItem

abstract class ApplicationMonitor extends Monitor {
  override val alertType = Seq(Application)
}

class ApplicationInfo(
    title: MonitorItem,
    appName: String,
    appId: String,
    md5: String,
    startTime: Date,
    duration: Long,
    appUiUrl: String,
    historyUrl: String,
    eventLogDir: String,
    minExecutor: Int,
    maxExecutor: Int,
    executorCore: Int,
    executorMemoryMB: Long,
    executorAccu: Double,
    user: String)
  extends AlertMessage(title) {
  override def toCsv(): String = {
    s"${user},${appId}," +
      s"${new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(startTime)}," +
      s"${Duration(duration, TimeUnit.MILLISECONDS).toSeconds}," +
      s"${executorMemoryMB},${executorCore},${executorAccu.formatted("%.2f")},${appName}"
  }
  // scalastyle:off
  override def toHtml(): String = {
    val html = <h1>Job finished!</h1>
        <h2>Job information</h2>
        <ul>
          <li>Application name: {appName}</li>
          <li>Application ID: {appId}</li>
          <li>Start time: {new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(startTime)}</li>
          <li>Duration: {Duration(duration, TimeUnit.MILLISECONDS).toSeconds} s</li>
        </ul>
        <h2>Resource usage</h2>
        <ul>
          <li>Executors: {minExecutor}~{maxExecutor}</li>
          <li>Executor memory: {executorMemoryMB} MB</li>
          <li>Executor cores: {executorCore}</li>
          <li>Accumulated executor usage: {executorAccu.formatted("%.2f")} executor*min</li>
        </ul>
        <h2>Debug information</h2>
        <ul>
          <li>History link 1: <a href={appUiUrl.split(",").head}>{appUiUrl.split(",").head}</a></li>
          <li>History link 2: <a href={historyUrl}>{historyUrl}</a></li>
          <li>Event log directory: {eventLogDir}</li>
        </ul>
    html.mkString
  }

  override def toJdbc(conn: Connection, appId: String): Unit = {
    val query = "INSERT INTO `xsql_monitor`.`spark_history`(" +
      "`user`, `md5`, `appId`, `startTime`, `duration`, " +
      "`yarnURL`, `sparkHistoryURL`, `eventLogDir`, `coresPerExecutor`, `memoryPerExecutorMB`," +
      " `executorAcc`, `appName`, `minExecutors`, `maxExecutors`)" +
      " SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? FROM DUAL" +
      " WHERE NOT EXISTS (SELECT * FROM `xsql_monitor`.`spark_history` WHERE `appId` = ?);"

    val preparedStmt = conn.prepareStatement(query)
    preparedStmt.setString(1, user)
    preparedStmt.setString(2, md5)
    preparedStmt.setString(3, appId)
    preparedStmt.setTimestamp(4, new Timestamp(startTime.getTime))
    preparedStmt.setLong(5, Duration(duration, TimeUnit.MILLISECONDS).toSeconds)
    preparedStmt.setString(6, appUiUrl)
    preparedStmt.setString(7, historyUrl)
    preparedStmt.setString(8, eventLogDir)
    preparedStmt.setInt(9, executorCore)
    preparedStmt.setLong(10, executorMemoryMB)
    preparedStmt.setDouble(11, executorAccu)
    preparedStmt.setString(12, appName)
    preparedStmt.setInt(13, minExecutor)
    preparedStmt.setInt(14, maxExecutor)
    preparedStmt.setString(15, appId)
    preparedStmt.execute
  }
} 
Example 42
Source File: ThriftServerTest.scala    From Hive-JDBC-Proxy   with Apache License 2.0 5 votes vote down vote up
package com.enjoyyin.hive.proxy.jdbc.test

import java.sql.{Connection, DriverManager, ResultSet, Statement}

import com.enjoyyin.hive.proxy.jdbc.util.Utils


private object ThriftServerTest extends App {
  val sql = """show tables"""
  val test_url = "jdbc:hive2://localhost:10001/default"
  Class.forName("org.apache.hive.jdbc.HiveDriver")
  def test(index: Int) = {
    var conn: Connection = null
    var stmt: Statement = null
    var rs: ResultSet = null
    Utils.tryFinally {
      conn = DriverManager.getConnection(test_url, "hduser0009", "")
      stmt = conn.createStatement
      rs = stmt.executeQuery(sql)
      while(rs.next) {
        println ("Date: " + Utils.dateFormat(System.currentTimeMillis) + ", " + index + ".tables => " + rs.getObject(1))
      }
      println("Date: " + Utils.dateFormat(System.currentTimeMillis) + ", ready to close " + index)
    } {
      if(rs != null) Utils.tryIgnoreError(rs.close())
      if(stmt != null) Utils.tryIgnoreError(stmt.close())
      if(conn != null) Utils.tryIgnoreError(conn.close())
    }
  }
  (0 until 8).foreach(i => new Thread {
    setName("thread-" + i)
    override def run(): Unit = {
      Utils.tryCatch(test(i)) { t =>
        println("Date: " + Utils.dateFormat(System.currentTimeMillis) + ", " + i + " has occur an error.")
        t.printStackTrace()
      }
    }
  }.start())
} 
Example 43
Source File: ConnectionPool.scala    From airframe   with Apache License 2.0 5 votes vote down vote up
package wvlet.airframe.jdbc

import java.sql.{Connection, PreparedStatement, ResultSet}

import wvlet.log.LogSupport
import wvlet.log.io.IOUtil.withResource

object ConnectionPool {
  def apply(config: DbConfig): ConnectionPool = {
    val pool: ConnectionPool = config.`type` match {
      case "sqlite" => new SQLiteConnectionPool(config)
      case other =>
        new GenericConnectionPool(config)
    }
    pool
  }

  def newFactory: ConnectionPoolFactory = new ConnectionPoolFactory()
}

trait ConnectionPool extends LogSupport with AutoCloseable {
  def config: DbConfig

  def withConnection[U](body: Connection => U): U
  def withTransaction[U](body: Connection => U): U = {
    withConnection { conn =>
      conn.setAutoCommit(false)
      var failed = false
      try {
        body(conn)
      } catch {
        case e: Throwable =>
          // Need to set the failed flag first because the rollback might fail
          failed = true
          conn.rollback()
          throw e
      } finally {
        if (failed == false) {
          conn.commit()
        }
      }
    }
  }

  def stop: Unit

  override def close(): Unit = stop

  def executeQuery[U](sql: String)(handler: ResultSet => U): U = {
    withConnection { conn =>
      withResource(conn.createStatement()) { stmt =>
        debug(s"execute query: ${sql}")
        withResource(stmt.executeQuery(sql)) { rs => handler(rs) }
      }
    }
  }
  def executeUpdate(sql: String): Int = {
    // TODO Add update retry
    withConnection { conn =>
      withResource(conn.createStatement()) { stmt =>
        debug(s"execute update: ${sql}")
        stmt.executeUpdate(sql)
      }
    }
  }

  def queryWith[U](preparedStatement: String)(body: PreparedStatement => Unit)(handler: ResultSet => U): U = {
    withConnection { conn =>
      withResource(conn.prepareStatement(preparedStatement)) { stmt =>
        body(stmt)
        debug(s"execute query: ${preparedStatement}")
        withResource(stmt.executeQuery) { rs => handler(rs) }
      }
    }
  }

  def updateWith(preparedStatement: String)(body: PreparedStatement => Unit): Unit = {
    withConnection { conn =>
      withResource(conn.prepareStatement(preparedStatement)) { stmt =>
        body(stmt)
        stmt.executeUpdate()
      }
    }
  }

} 
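A sketch of composing the pool helpers above. It assumes a DbConfig instance is already available (its construction is not shown in this example); the table and column names are illustrative:

import java.sql.ResultSet
import wvlet.airframe.jdbc.{ConnectionPool, DbConfig}

object ConnectionPoolUsage {
  def activeUserCount(config: DbConfig): Int = {
    val pool = ConnectionPool(config)
    try {
      // queryWith binds parameters in the second argument list and handles the ResultSet in the third.
      pool.queryWith("select count(*) from users where active = ?") { stmt =>
        stmt.setBoolean(1, true)
      } { rs: ResultSet =>
        rs.next()
        rs.getInt(1)
      }
    } finally {
      pool.close()
    }
  }
}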
Example 44
Source File: SQLiteConnectionPool.scala    From airframe   with Apache License 2.0 5 votes vote down vote up
package wvlet.airframe.jdbc

import java.io.File
import java.sql.{Connection, DriverManager}

import wvlet.log.Guard


class SQLiteConnectionPool(val config: DbConfig) extends ConnectionPool with Guard {
  private var conn: Connection = newConnection

  private def newConnection: Connection = {
    // Prepare parent db folder
    Option(new File(config.database).getParentFile).map { p =>
      if (!p.exists()) {
        info(s"Create db folder: ${p}")
        p.mkdirs()
      }
    }

    val jdbcUrl = config.jdbcUrl
    info(s"Opening ${jdbcUrl}")
    // We need to explicitly load sqlite-jdbc to cope with SBT's peculiar class loader
    Class.forName(config.jdbcDriverName)
    val conn = DriverManager.getConnection(jdbcUrl)
    conn.setAutoCommit(true)
    conn
  }

  def withConnection[U](body: Connection => U): U = {
    guard {
      if (conn.isClosed) {
        conn = newConnection
      }
      // In sqlite-jdbc, we can reuse the same connection instance,
      // and we have no need to close the connection
      body(conn)
    }
  }

  def stop: Unit = {
    info(s"Closing the connection pool for ${config.jdbcUrl}")
    conn.close()
  }
} 
Example 45
Source File: HttpRecord.scala    From airframe   with Apache License 2.0 5 votes vote down vote up
package wvlet.airframe.http.recorder
import java.sql.{Connection, ResultSet}
import java.time.Instant

import com.twitter.finagle.http.{Response, Status, Version}
import com.twitter.io.Buf
import wvlet.airframe.codec._
import wvlet.airframe.control.Control.withResource
import wvlet.airframe.http.recorder.HttpRecord.headerCodec
import wvlet.log.LogSupport


case class HttpRecord(
    session: String,
    requestHash: Int,
    method: String,
    destHost: String,
    path: String,
    requestHeader: Seq[(String, String)],
    requestBody: String,
    responseCode: Int,
    responseHeader: Seq[(String, String)],
    responseBody: String,
    createdAt: Instant
) {
  def summary: String = {
    s"${method}(${responseCode}) ${destHost}${path}: ${responseBody.substring(0, 30.min(responseBody.size))} ..."
  }

  def toResponse: Response = {
    val r = Response(Version.Http11, Status.fromCode(responseCode))

    responseHeader.foreach { x => r.headerMap.set(x._1, x._2) }

    // Decode binary contents with Base64
    val contentBytes = HttpRecordStore.decodeFromBase64(responseBody)
    r.content = Buf.ByteArray.Owned(contentBytes)
    r.contentLength = contentBytes.length
    r
  }

  def insertInto(tableName: String, conn: Connection): Unit = {
    withResource(conn.prepareStatement(s"""|insert into "${tableName}" values(
          |?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
          |)
      """.stripMargin)) { prep =>
      // TODO Implement this logic in JDBCResultSetCodec
      prep.setString(1, session)
      prep.setInt(2, requestHash)
      prep.setString(3, method)
      prep.setString(4, destHost)
      prep.setString(5, path)
      prep.setString(6, JSONCodec.toJson(headerCodec.toMsgPack(requestHeader)))
      prep.setString(7, requestBody)
      prep.setInt(8, responseCode)
      prep.setString(9, JSONCodec.toJson(headerCodec.toMsgPack(responseHeader)))
      prep.setString(10, responseBody)
      prep.setString(11, createdAt.toString)

      prep.execute()
    }
  }
}

object HttpRecord extends LogSupport {
  private[recorder] val headerCodec                               = MessageCodec.of[Seq[(String, String)]]
  private[recorder] val recordCodec                               = MessageCodec.of[HttpRecord]
  private[recorder] def createTableSQL(tableName: String): String =
    // TODO: Add a method to generate this SQL statement in airframe-codec
    s"""create table if not exists "${tableName}" (
       |  session string,
       |  requestHash string,
       |  method string,
       |  destHost string,
       |  path string,
       |  requestHeader string,
       |  requestBody string,
       |  responseCode int,
       |  responseHeader string,
       |  responseBody string,
       |  createdAt string
       |)
     """.stripMargin

  private[recorder] def read(rs: ResultSet): Seq[HttpRecord] = {
    val resultSetCodec = JDBCCodec(rs)
    resultSetCodec
      .mapMsgPackMapRows(msgpack => recordCodec.unpackBytes(msgpack))
      .filter(_.isDefined)
      .map(_.get)
      .toSeq
  }
} 
Example 46
Source File: Database.scala    From schedoscope   with Apache License 2.0 5 votes vote down vote up
package org.schedoscope.test

import java.sql.{Connection, ResultSet, Statement}

import org.schedoscope.dsl.{FieldLike, View}
import org.schedoscope.schema.ddl.HiveQl

import scala.collection.mutable.{HashMap, ListBuffer}

class Database(conn: Connection, url: String) {

  def selectForViewByQuery(v: View, query: String, orderByField: Option[FieldLike[_]]): List[Map[String, Any]] = {
    val res = ListBuffer[Map[String, Any]]()
    var statement: Statement = null
    var rs: ResultSet = null

    try {
      statement = conn.createStatement()
      rs = statement.executeQuery(query)

      while (rs.next()) {
        val row = HashMap[String, Any]()
        v.fields.view.zipWithIndex.foreach(f => {
          row.put(f._1.n, ViewSerDe.deserializeField(f._1.t, rs.getString(f._2 + 1)))
        })
        res.append(row.toMap)
      }
    }
    finally {
      if (rs != null) try {
        rs.close()
      } catch {
        case _: Throwable =>
      }

      if (statement != null) try {
        statement.close()
      } catch {
        case _: Throwable =>
      }
    }

    orderByField match {
      case Some(f) => res.sortBy {
        _ (f.n) match {
          case null => ""
          case other => other.toString
        }
      } toList
      case None => res.toList
    }
  }

  def selectView(v: View, orderByField: Option[FieldLike[_]]): List[Map[String, Any]] =
    selectForViewByQuery(v, HiveQl.selectAll(v), orderByField)

} 
Example 47
Source File: SqlProtocol.scala    From gatling-sql   with Apache License 2.0 5 votes vote down vote up
package io.github.gatling.sql.protocol

import java.sql.Connection

import akka.actor.ActorSystem
import io.gatling.core
import io.gatling.core.CoreComponents
import io.gatling.core.config.GatlingConfiguration
import io.gatling.core.protocol.{Protocol, ProtocolComponents, ProtocolKey}
import io.gatling.core.session.Session


case class SqlProtocol(connection: Connection) extends Protocol {
  type Components = SqlComponents
}

object SqlProtocol {

  val SqlProtocolKey = new ProtocolKey {

    type Protocol = SqlProtocol
    type Components = SqlComponents

    override def protocolClass: Class[core.protocol.Protocol] = classOf[SqlProtocol].asInstanceOf[Class[io.gatling.core.protocol.Protocol]]

    override def defaultProtocolValue(configuration: GatlingConfiguration): SqlProtocol = throw new IllegalStateException("Can't provide a default value for SqlProtocol")

    override def newComponents(system: ActorSystem, coreComponents: CoreComponents): SqlProtocol => SqlComponents = {
      sqlProtocol => SqlComponents(sqlProtocol)
    }
  }
}

case class SqlComponents(sqlProtocol: SqlProtocol) extends ProtocolComponents {
  def onStart: Option[Session => Session] = None
  def onExit: Option[Session => Unit] = None
}

case class SqlProtocolBuilder(connection: Connection) {
  def build() = SqlProtocol(connection)
}

object SqlProtocolBuilder {
  def connection(connection: Connection) = SqlProtocolBuilder(connection)
} 
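A sketch of building the protocol from an existing JDBC connection; the URL is a placeholder and the surrounding Gatling simulation DSL is omitted:

import java.sql.DriverManager
import io.github.gatling.sql.protocol.SqlProtocolBuilder

object SqlProtocolUsage {
  val sqlProtocol = SqlProtocolBuilder
    .connection(DriverManager.getConnection("jdbc:h2:mem:gatling;DB_CLOSE_DELAY=-1"))
    .build()
}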
Example 48
Source File: SqlStatement.scala    From gatling-sql   with Apache License 2.0 5 votes vote down vote up
package io.github.gatling.sql

import java.sql.{Connection, PreparedStatement}

import com.typesafe.scalalogging.StrictLogging
import io.github.gatling.sql.db.ConnectionPool
import io.gatling.commons.validation.Validation
import io.gatling.core.session.{Expression, Session}
import io.gatling.commons.validation._

trait SqlStatement extends StrictLogging {

  def apply(session:Session): Validation[PreparedStatement]

  def connection = ConnectionPool.connection
}

case class SimpleSqlStatement(statement: Expression[String]) extends SqlStatement {
  def apply(session: Session): Validation[PreparedStatement] = statement(session).flatMap { stmt =>
      logger.debug(s"STMT: ${stmt}")
      connection.prepareStatement(stmt).success
    }
} 
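A sketch of building a statement from a Gatling session expression. The SQL text is illustrative, and ConnectionPool.connection is assumed to be configured elsewhere:

import io.gatling.commons.validation._
import io.gatling.core.session.Session
import io.github.gatling.sql.SimpleSqlStatement

object SqlStatementUsage {
  // An Expression[String] is just Session => Validation[String]; .success comes from the validation import.
  val selectOne = SimpleSqlStatement((_: Session) => "select 1".success)
}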
Example 49
Source File: BaseThriftIntegrationTestSuite.scala    From incubator-livy   with Apache License 2.0 5 votes vote down vote up
package org.apache.livy.test.framework

import java.sql.{Connection, DriverManager, ResultSet}

class BaseThriftIntegrationTestSuite extends BaseIntegrationTestSuite {
  private var jdbcUri: String = _

  override def beforeAll(): Unit = {
    cluster = Cluster.get()
    // The JDBC endpoint must contain a valid value
    assert(cluster.jdbcEndpoint.isDefined)
    jdbcUri = cluster.jdbcEndpoint.get
  }

  def checkQuery(connection: Connection, query: String)(validate: ResultSet => Unit): Unit = {
    val ps = connection.prepareStatement(query)
    try {
      val rs = ps.executeQuery()
      try {
        validate(rs)
      } finally {
        rs.close()
      }
    } finally {
      ps.close()
    }
  }

  def withConnection[T](f: Connection => T): T = {
    val connection = DriverManager.getConnection(jdbcUri)
    try {
      f(connection)
    } finally {
      connection.close()
    }
  }
} 
Example 50
Source File: ThriftServerBaseTest.scala    From incubator-livy   with Apache License 2.0 5 votes vote down vote up
package org.apache.livy.thriftserver

import java.sql.{Connection, DriverManager, Statement}

import org.apache.hive.jdbc.HiveDriver
import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.livy.LivyConf
import org.apache.livy.LivyConf.{LIVY_SPARK_SCALA_VERSION, LIVY_SPARK_VERSION}
import org.apache.livy.server.AccessManager
import org.apache.livy.server.recovery.{SessionStore, StateStore}
import org.apache.livy.sessions.InteractiveSessionManager
import org.apache.livy.utils.LivySparkUtils.{formatSparkVersion, sparkScalaVersion, sparkSubmitVersion}

object ServerMode extends Enumeration {
  val binary, http = Value
}

abstract class ThriftServerBaseTest extends FunSuite with BeforeAndAfterAll {
  def mode: ServerMode.Value
  def port: Int

  val THRIFT_SERVER_STARTUP_TIMEOUT = 30000 // ms

  val livyConf = new LivyConf()
  val (sparkVersion, scalaVersionFromSparkSubmit) = sparkSubmitVersion(livyConf)
  val formattedSparkVersion: (Int, Int) = {
    formatSparkVersion(sparkVersion)
  }

  def jdbcUri(defaultDb: String, sessionConf: String*): String = if (mode == ServerMode.http) {
    s"jdbc:hive2://localhost:$port/$defaultDb?hive.server2.transport.mode=http;" +
      s"hive.server2.thrift.http.path=cliservice;${sessionConf.mkString(";")}"
  } else {
    s"jdbc:hive2://localhost:$port/$defaultDb?${sessionConf.mkString(";")}"
  }

  override def beforeAll(): Unit = {
    Class.forName(classOf[HiveDriver].getCanonicalName)
    livyConf.set(LivyConf.THRIFT_TRANSPORT_MODE, mode.toString)
    livyConf.set(LivyConf.THRIFT_SERVER_PORT, port)

    // Set formatted Spark and Scala version into livy configuration, this will be used by
    // session creation.
    livyConf.set(LIVY_SPARK_VERSION.key, formattedSparkVersion.productIterator.mkString("."))
    livyConf.set(LIVY_SPARK_SCALA_VERSION.key,
      sparkScalaVersion(formattedSparkVersion, scalaVersionFromSparkSubmit, livyConf))
    StateStore.init(livyConf)

    val ss = new SessionStore(livyConf)
    val sessionManager = new InteractiveSessionManager(livyConf, ss)
    val accessManager = new AccessManager(livyConf)
    LivyThriftServer.start(livyConf, sessionManager, ss, accessManager)
    LivyThriftServer.thriftServerThread.join(THRIFT_SERVER_STARTUP_TIMEOUT)
    assert(LivyThriftServer.getInstance.isDefined)
    assert(LivyThriftServer.getInstance.get.getServiceState == STATE.STARTED)
  }

  override def afterAll(): Unit = {
    LivyThriftServer.stopServer()
  }

  def withJdbcConnection(f: (Connection => Unit)): Unit = {
    withJdbcConnection("default", Seq.empty)(f)
  }

  def withJdbcConnection(db: String, sessionConf: Seq[String])(f: (Connection => Unit)): Unit = {
    withJdbcConnection(jdbcUri(db, sessionConf: _*))(f)
  }

  def withJdbcConnection(uri: String)(f: (Connection => Unit)): Unit = {
    val user = System.getProperty("user.name")
    val connection = DriverManager.getConnection(uri, user, "")
    try {
      f(connection)
    } finally {
      connection.close()
    }
  }

  def withJdbcStatement(f: (Statement => Unit)): Unit = {
    withJdbcConnection { connection =>
      val s = connection.createStatement()
      try {
        f(s)
      } finally {
        s.close()
      }
    }
  }
} 
Example 51
Source File: JdbcSessionImpl.scala    From lagom   with Apache License 2.0 5 votes vote down vote up
package com.lightbend.lagom.internal.scaladsl.persistence.jdbc

import java.sql.Connection

import com.lightbend.lagom.internal.persistence.jdbc.SlickProvider
import com.lightbend.lagom.scaladsl.persistence.jdbc.JdbcSession

import scala.concurrent.Future


final class JdbcSessionImpl(slick: SlickProvider) extends JdbcSession {
  import slick.profile.api._

  override def withConnection[T](block: Connection => T): Future[T] = {
    slick.db.run {
      SimpleDBIO { ctx =>
        block(ctx.connection)
      }
    }
  }

  override def withTransaction[T](block: Connection => T): Future[T] = {
    slick.db.run {
      SimpleDBIO { ctx =>
        block(ctx.connection)
      }.transactionally
    }
  }
} 
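A sketch of how application code typically uses the JdbcSession API shown above; the repository and table names are illustrative:

import java.sql.Connection
import scala.concurrent.Future
import com.lightbend.lagom.scaladsl.persistence.jdbc.JdbcSession

class OrderRepository(session: JdbcSession) {
  def countOrders(): Future[Long] =
    session.withConnection { connection: Connection =>
      // Runs on Slick's database dispatcher and returns the result asynchronously.
      val rs = connection.createStatement().executeQuery("select count(*) from orders")
      rs.next()
      rs.getLong(1)
    }
}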
Example 52
Source File: JdbcTestEntityReadSide.scala    From lagom   with Apache License 2.0 5 votes vote down vote up
package com.lightbend.lagom.scaladsl.persistence.jdbc

import java.sql.Connection
import com.lightbend.lagom.scaladsl.persistence.ReadSideProcessor.ReadSideHandler
import com.lightbend.lagom.scaladsl.persistence.TestEntity.Evt
import com.lightbend.lagom.scaladsl.persistence.AggregateEventTag
import com.lightbend.lagom.scaladsl.persistence.EventStreamElement
import com.lightbend.lagom.scaladsl.persistence.ReadSideProcessor
import com.lightbend.lagom.scaladsl.persistence.TestEntity

import scala.concurrent.Future

object JdbcTestEntityReadSide {
  class TestEntityReadSideProcessor(readSide: JdbcReadSide) extends ReadSideProcessor[TestEntity.Evt] {
    import JdbcSession.tryWith

    def buildHandler(): ReadSideHandler[TestEntity.Evt] =
      readSide
        .builder[TestEntity.Evt]("test-entity-read-side")
        .setGlobalPrepare(this.createTable)
        .setEventHandler(updateCount _)
        .build()

    private def createTable(connection: Connection): Unit = {
      tryWith(connection.prepareCall("create table if not exists testcounts (id varchar primary key, count bigint)")) {
        _.execute()
      }
    }

    private def updateCount(connection: Connection, event: EventStreamElement[TestEntity.Appended]): Unit = {
      tryWith(connection.prepareStatement("select count from testcounts where id = ?")) { statement =>
        statement.setString(1, event.entityId)
        tryWith(statement.executeQuery) { rs =>
          tryWith(if (rs.next) {
            val count: Long = rs.getLong("count")
            val update      = connection.prepareStatement("update testcounts set count = ? where id = ?")
            update.setLong(1, count + 1)
            update.setString(2, event.entityId)
            update
          } else {
            val update = connection.prepareStatement("insert into testcounts values (?, 1)")
            update.setString(1, event.entityId)
            update
          })(_.execute)
        }
      }
    }

    def aggregateTags: Set[AggregateEventTag[Evt]] = TestEntity.Evt.aggregateEventShards.allTags
  }
}

class JdbcTestEntityReadSide(session: JdbcSession) {
  import JdbcSession.tryWith

  def getAppendCount(id: String): Future[Long] =
    session.withConnection(connection => {
      tryWith(connection.prepareStatement("select count from testcounts where id = ?")) { statement =>
        statement.setString(1, id)

        tryWith(statement.executeQuery()) { rs =>
          if (rs.next()) rs.getLong("count")
          else 0L
        }
      }
    })
} 
Example 53
Source File: PostgresDialect.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Types}

import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._


private object PostgresDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.REAL) {
      Some(FloatType)
    } else if (sqlType == Types.SMALLINT) {
      Some(ShortType)
    } else if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
      Some(BinaryType)
    } else if (sqlType == Types.OTHER) {
      Some(StringType)
    } else if (sqlType == Types.ARRAY) {
      val scale = md.build.getLong("scale").toInt
      // postgres array type names start with underscore
      toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_))
    } else None
  }

  private def toCatalystType(
      typeName: String,
      precision: Int,
      scale: Int): Option[DataType] = typeName match {
    case "bool" => Some(BooleanType)
    case "bit" => Some(BinaryType)
    case "int2" => Some(ShortType)
    case "int4" => Some(IntegerType)
    case "int8" | "oid" => Some(LongType)
    case "float4" => Some(FloatType)
    case "money" | "float8" => Some(DoubleType)
    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
      Some(StringType)
    case "bytea" => Some(BinaryType)
    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
    case "date" => Some(DateType)
    case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale))
    case _ => None
  }

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.CHAR))
    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
    case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
    case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case t: DecimalType => Some(
      JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
      getJDBCType(et).map(_.databaseTypeDefinition)
        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
    case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt")
    case _ => None
  }

  override def getTableExistsQuery(table: String): String = {
    s"SELECT 1 FROM $table LIMIT 1"
  }

  override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
    super.beforeFetch(connection, properties)

    // According to the PostgreSQL JDBC documentation, the connection must have autocommit disabled
    // for a non-zero fetch size to take effect; otherwise the driver fetches and caches the entire
    // result set on the client.
    //
    // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
    //
    if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
      connection.setAutoCommit(false)
    }

  }

  override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
} 
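The JdbcDialect API shown here is also open to user code. A hedged sketch of registering a custom dialect (MyDialect and the jdbc:mydb URL prefix are made up for illustration):

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Illustrative dialect: claim URLs with a made-up "jdbc:mydb" prefix and quote
// identifiers with backticks.
object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
  override def quoteIdentifier(colName: String): String = s"`$colName`"
}

object RegisterMyDialect extends App {
  // Spark selects a dialect by calling canHandle(url) on each registered dialect.
  JdbcDialects.registerDialect(MyDialect)
}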
Example 54
Source File: JdbcSQLite.scala    From Scientific-Computing-with-Scala   with MIT License 5 votes vote down vote up
import java.sql.DriverManager
import java.sql.Connection

object JdbcSqlite {
  def main(args: Array[String]): Unit = {
    var c: Connection = null
    try {
      Class.forName("org.sqlite.JDBC")
      c = DriverManager.getConnection("jdbc:sqlite:planets.sqlite")
    } catch {
      case e: Throwable => e.printStackTrace()
    } finally {
      if (c != null) c.close() // close only if the connection was actually opened
    }
  }
} 
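Once the connection is open, statements and result sets need the same care. A hedged sketch of querying the database, assuming a hypothetical planets(name TEXT) table exists in planets.sqlite:

import java.sql.{Connection, DriverManager}

object JdbcSqliteQuery extends App {
  val conn: Connection = DriverManager.getConnection("jdbc:sqlite:planets.sqlite")
  try {
    val stmt = conn.createStatement()
    try {
      // "planets" is a hypothetical table used only for illustration
      val rs = stmt.executeQuery("SELECT name FROM planets")
      while (rs.next()) println(rs.getString("name"))
    } finally stmt.close() // closing the Statement also closes its ResultSet
  } finally conn.close()
}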
Example 55
Source File: H2Sandbox.scala    From redshift-fake-driver   with Apache License 2.0 5 votes vote down vote up
package jp.ne.opt.redshiftfake

import java.sql.{DriverManager, Connection}
import java.util.Properties

import jp.ne.opt.redshiftfake.util.Loan.using
import org.scalatest.{Outcome, fixture}

trait H2Sandbox { self: fixture.TestSuite =>

  type FixtureParam = Connection

  override def withFixture(test: OneArgTest): Outcome = {
    val url = "jdbc:h2redshift:mem:redshift;MODE=PostgreSQL;DATABASE_TO_UPPER=false"
    val prop = new Properties()
    prop.setProperty("driver", "org.h2.jdbc.FakeH2Driver")
    prop.setProperty("user", "sa")

    Class.forName("org.h2.jdbc.FakeH2Driver")
    using(DriverManager.getConnection(url, prop))(test)
  }
} 
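A hedged usage sketch of the fixture: a suite mixes in H2Sandbox and receives the Connection as its fixture parameter (assuming the same ScalaTest 3.0-style fixture.FunSuite that the trait's fixture.TestSuite self-type comes from):

import jp.ne.opt.redshiftfake.H2Sandbox
import org.scalatest.fixture

class H2SandboxUsageSpec extends fixture.FunSuite with H2Sandbox {
  test("the in-memory H2/Redshift database answers a trivial query") { conn =>
    // conn is the Connection opened by withFixture and closed when the test finishes
    val rs = conn.createStatement().executeQuery("SELECT 1")
    assert(rs.next() && rs.getInt(1) == 1)
  }
}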
Example 56
Source File: MemsqlRDD.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import java.sql.{Connection, PreparedStatement, ResultSet}

import com.memsql.spark.SQLGen.VariableList
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}

case class MemsqlRDD(query: String,
                     variables: VariableList,
                     options: MemsqlOptions,
                     schema: StructType,
                     expectedOutput: Seq[Attribute],
                     @transient val sc: SparkContext)
    extends RDD[Row](sc, Nil) {

  override protected def getPartitions: Array[Partition] =
    MemsqlQueryHelpers.GetPartitions(options, query, variables)

  override def compute(rawPartition: Partition, context: TaskContext): Iterator[Row] = {
    var closed                     = false
    var rs: ResultSet              = null
    var stmt: PreparedStatement    = null
    var conn: Connection           = null
    val partition: MemsqlPartition = rawPartition.asInstanceOf[MemsqlPartition]

    def tryClose(name: String, what: AutoCloseable): Unit = {
      try {
        if (what != null) { what.close() }
      } catch {
        case e: Exception => logWarning(s"Exception closing $name", e)
      }
    }

    def close(): Unit = {
      if (closed) { return }
      tryClose("resultset", rs)
      tryClose("statement", stmt)
      tryClose("connection", conn)
      closed = true
    }

    context.addTaskCompletionListener { context =>
      close()
    }

    conn = JdbcUtils.createConnectionFactory(partition.connectionInfo)()
    stmt = conn.prepareStatement(partition.query)
    JdbcHelpers.fillStatement(stmt, partition.variables)
    rs = stmt.executeQuery()

    var rowsIter = JdbcUtils.resultSetToRows(rs, schema)

    if (expectedOutput.nonEmpty) {
      val schemaDatatypes   = schema.map(_.dataType)
      val expectedDatatypes = expectedOutput.map(_.dataType)

      if (schemaDatatypes != expectedDatatypes) {
        val columnEncoders = schemaDatatypes.zip(expectedDatatypes).zipWithIndex.map {
          case ((_: StringType, _: NullType), _)     => ((_: Row) => null)
          case ((_: ShortType, _: BooleanType), i)   => ((r: Row) => r.getShort(i) != 0)
          case ((_: IntegerType, _: BooleanType), i) => ((r: Row) => r.getInt(i) != 0)
          case ((_: LongType, _: BooleanType), i)    => ((r: Row) => r.getLong(i) != 0)

          case ((l, r), i) => {
            options.assert(l == r, s"MemsqlRDD: unable to encode ${l} into ${r}")
            ((r: Row) => r.get(i))
          }
        }

        rowsIter = rowsIter
          .map(row => Row.fromSeq(columnEncoders.map(_(row))))
      }
    }

    CompletionIterator[Row, Iterator[Row]](new InterruptibleIterator[Row](context, rowsIter), close)
  }

} 
Example 57
Source File: BinaryTypeBenchmark.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import java.sql.{Connection, DriverManager}
import java.util.Properties

import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import com.memsql.spark.BatchInsertBenchmark.{df, executeQuery}
import org.apache.spark.sql.types.{BinaryType, IntegerType}
import org.apache.spark.sql.{SaveMode, SparkSession}

import scala.util.Random

// BinaryTypeBenchmark is written to profile writing of BinaryType columns with a CPU profiler;
// this feature is accessible in the Ultimate version of IntelliJ IDEA
// see https://www.jetbrains.com/help/idea/async-profiler.html#profile for more details
object BinaryTypeBenchmark extends App {
  final val masterHost: String = sys.props.getOrElse("memsql.host", "localhost")
  final val masterPort: String = sys.props.getOrElse("memsql.port", "5506")

  val spark: SparkSession = SparkSession
    .builder()
    .master("local")
    .config("spark.sql.shuffle.partitions", "1")
    .config("spark.driver.bindAddress", "localhost")
    .config("spark.datasource.memsql.ddlEndpoint", s"${masterHost}:${masterPort}")
    .config("spark.datasource.memsql.database", "testdb")
    .getOrCreate()

  def jdbcConnection: Loan[Connection] = {
    val connProperties = new Properties()
    connProperties.put("user", "root")

    Loan(
      DriverManager.getConnection(
        s"jdbc:mysql://$masterHost:$masterPort",
        connProperties
      ))
  }

  def executeQuery(sql: String): Unit = {
    jdbcConnection.to(conn => Loan(conn.createStatement).to(_.execute(sql)))
  }

  executeQuery("set global default_partitions_per_leaf = 2")
  executeQuery("drop database if exists testdb")
  executeQuery("create database testdb")

  def genRandomByte(): Byte = (Random.nextInt(256) - 128).toByte
  def genRandomRow(): Array[Byte] =
    Array.fill(1000)(genRandomByte())

  val df = spark.createDF(
    List.fill(100000)(genRandomRow()).zipWithIndex,
    List(("data", BinaryType, true), ("id", IntegerType, true))
  )

  val start1 = System.nanoTime()
  df.write
    .format("memsql")
    .mode(SaveMode.Overwrite)
    .save("testdb.LoadData")

  println("Elapsed time: " + (System.nanoTime() - start1) + "ns [LoadData CSV]")

  val start2 = System.nanoTime()
  df.write
    .format("memsql")
    .option("tableKey.primary", "id")
    .option("onDuplicateKeySQL", "id = id")
    .mode(SaveMode.Overwrite)
    .save("testdb.BatchInsert")

  println("Elapsed time: " + (System.nanoTime() - start2) + "ns [BatchInsert]")

  val avroStart = System.nanoTime()
  df.write
    .format(DefaultSource.MEMSQL_SOURCE_NAME_SHORT)
    .mode(SaveMode.Overwrite)
    .option(MemsqlOptions.LOAD_DATA_FORMAT, "Avro")
    .save("testdb.AvroSerialization")
  println("Elapsed time: " + (System.nanoTime() - avroStart) + "ns [LoadData Avro] ")
} 
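The benchmarks above lean on a small Loan helper from the same project whose definition is not shown on this page. A minimal stand-in sketch (an assumption about its shape, not the connector's actual class):

// Hypothetical stand-in for the connector's Loan helper: wrap an AutoCloseable,
// apply a function to it, and always close it afterwards.
case class LoanSketch[A <: AutoCloseable](resource: A) {
  def to[B](handler: A => B): B =
    try handler(resource)
    finally resource.close()
}

// Usage mirrors executeQuery above:
// LoanSketch(conn.createStatement()).to(_.execute("SELECT 1"))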
Example 58
Source File: LoadDataBenchmark.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import java.sql.{Connection, Date, DriverManager}
import java.time.{Instant, LocalDate}
import java.util.Properties

import org.apache.spark.sql.types._
import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import org.apache.spark.sql.{SaveMode, SparkSession}

import scala.util.Random

// LoadDataBenchmark is written to test LOAD DATA ingestion with a CPU profiler;
// this feature is accessible in the Ultimate version of IntelliJ IDEA
// see https://www.jetbrains.com/help/idea/async-profiler.html#profile for more details
object LoadDataBenchmark extends App {
  final val masterHost: String = sys.props.getOrElse("memsql.host", "localhost")
  final val masterPort: String = sys.props.getOrElse("memsql.port", "5506")

  val spark: SparkSession = SparkSession
    .builder()
    .master("local")
    .config("spark.sql.shuffle.partitions", "1")
    .config("spark.driver.bindAddress", "localhost")
    .config("spark.datasource.memsql.ddlEndpoint", s"${masterHost}:${masterPort}")
    .config("spark.datasource.memsql.database", "testdb")
    .getOrCreate()

  def jdbcConnection: Loan[Connection] = {
    val connProperties = new Properties()
    connProperties.put("user", "root")

    Loan(
      DriverManager.getConnection(
        s"jdbc:mysql://$masterHost:$masterPort",
        connProperties
      ))
  }

  def executeQuery(sql: String): Unit = {
    jdbcConnection.to(conn => Loan(conn.createStatement).to(_.execute(sql)))
  }

  executeQuery("set global default_partitions_per_leaf = 2")
  executeQuery("drop database if exists testdb")
  executeQuery("create database testdb")

  def genRow(): (Long, Int, Double, String) =
    (Random.nextLong(), Random.nextInt(), Random.nextDouble(), Random.nextString(20))
  val df =
    spark.createDF(
      List.fill(1000000)(genRow()),
      List(("LongType", LongType, true),
           ("IntType", IntegerType, true),
           ("DoubleType", DoubleType, true),
           ("StringType", StringType, true))
    )

  val start = System.nanoTime()
  df.write
    .format("memsql")
    .mode(SaveMode.Append)
    .save("testdb.batchinsert")

  val diff = System.nanoTime() - start
  println("Elapsed time: " + diff + "ns [CSV serialization] ")

  executeQuery("truncate testdb.batchinsert")

  val avroStart = System.nanoTime()
  df.write
    .format(DefaultSource.MEMSQL_SOURCE_NAME_SHORT)
    .mode(SaveMode.Append)
    .option(MemsqlOptions.LOAD_DATA_FORMAT, "Avro")
    .save("testdb.batchinsert")
  val avroDiff = System.nanoTime() - avroStart
  println("Elapsed time: " + avroDiff + "ns [Avro serialization] ")
} 
Example 59
Source File: BatchInsertBenchmark.scala    From memsql-spark-connector   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark

import java.sql.{Connection, Date, DriverManager}
import java.time.LocalDate
import java.util.Properties

import org.apache.spark.sql.types._
import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import org.apache.spark.sql.{SaveMode, SparkSession}

import scala.util.Random

// BatchInsertBenchmark is written to test batch inserts with a CPU profiler;
// this feature is accessible in the Ultimate version of IntelliJ IDEA
// see https://www.jetbrains.com/help/idea/async-profiler.html#profile for more details
object BatchInsertBenchmark extends App {
  final val masterHost: String = sys.props.getOrElse("memsql.host", "localhost")
  final val masterPort: String = sys.props.getOrElse("memsql.port", "5506")

  val spark: SparkSession = SparkSession
    .builder()
    .master("local")
    .config("spark.sql.shuffle.partitions", "1")
    .config("spark.driver.bindAddress", "localhost")
    .config("spark.datasource.memsql.ddlEndpoint", s"${masterHost}:${masterPort}")
    .config("spark.datasource.memsql.database", "testdb")
    .getOrCreate()

  def jdbcConnection: Loan[Connection] = {
    val connProperties = new Properties()
    connProperties.put("user", "root")

    Loan(
      DriverManager.getConnection(
        s"jdbc:mysql://$masterHost:$masterPort",
        connProperties
      ))
  }

  def executeQuery(sql: String): Unit = {
    jdbcConnection.to(conn => Loan(conn.createStatement).to(_.execute(sql)))
  }

  executeQuery("set global default_partitions_per_leaf = 2")
  executeQuery("drop database if exists testdb")
  executeQuery("create database testdb")

  def genDate() =
    Date.valueOf(LocalDate.ofEpochDay(LocalDate.of(2001, 4, 11).toEpochDay + Random.nextInt(10000)))
  def genRow(): (Long, Int, Double, String, Date) =
    (Random.nextLong(), Random.nextInt(), Random.nextDouble(), Random.nextString(20), genDate())
  val df =
    spark.createDF(
      List.fill(1000000)(genRow()),
      List(("LongType", LongType, true),
           ("IntType", IntegerType, true),
           ("DoubleType", DoubleType, true),
           ("StringType", StringType, true),
           ("DateType", DateType, true))
    )

  val start = System.nanoTime()
  df.write
    .format("memsql")
    .option("tableKey.primary", "IntType")
    .option("onDuplicateKeySQL", "IntType = IntType")
    .mode(SaveMode.Append)
    .save("testdb.batchinsert")

  val diff = System.nanoTime() - start
  println("Elapsed time: " + diff + "ns")
} 
Example 60
Source File: DaoServiceComponent.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.scheduler.dao

import java.sql.{Connection, ResultSet, PreparedStatement}

trait DaoService {
  def getConnection(): Connection
  
  def executeSelect[T](preparedStatement: PreparedStatement)(f: (ResultSet) => List[T]): List[T] =
    try {
      f(preparedStatement.executeQuery())
    } finally {
      preparedStatement.close()
    }

  def readResultSet[T](rs: ResultSet)(f: ResultSet => T): List[T] =
    Iterator.continually((rs.next(), rs)).takeWhile(_._1).map {
      case (_, row) => f(row)
    }.toList
}

trait DaoServiceComponent {
  this: DatabaseServiceComponent =>
  
  val daoService: DaoService
  
  class DaoServiceImpl extends DaoService {
    override def getConnection(): Connection = databaseService.getConnection
  }
} 
Example 61
package com.ivan.nikolov.scheduler.dao

import java.sql.Connection
import javax.sql.DataSource

import com.ivan.nikolov.scheduler.config.app.AppConfigComponent
import org.h2.jdbcx.JdbcConnectionPool

trait DatabaseService {
  val dbDriver: String
  val connectionString: String
  val username: String
  val password: String
  val ds: DataSource

  def getConnection: Connection = ds.getConnection
}

trait DatabaseServiceComponent {
  this: AppConfigComponent =>
  
  val databaseService: DatabaseService
  
  class H2DatabaseService extends DatabaseService {
    override val dbDriver: String = "org.h2.Driver"
    override val connectionString: String = appConfigService.dbConnectionString
    override val username: String = appConfigService.dbUsername
    override val password: String = appConfigService.dbPassword
    override val ds: DataSource = JdbcConnectionPool.create(connectionString, username, password)
  }
} 
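The two components above use the cake pattern: each trait's self-type names its dependency, and the application wires everything together in one registry object. A hedged wiring sketch; ComponentRegistry is a hypothetical name, and AppConfigComponent from the same project is assumed to already supply a concrete appConfigService (otherwise the registry must define one too):

import com.ivan.nikolov.scheduler.config.app.AppConfigComponent
import com.ivan.nikolov.scheduler.dao.{DaoServiceComponent, DatabaseServiceComponent}

object ComponentRegistry
    extends AppConfigComponent // assumed to provide appConfigService
    with DatabaseServiceComponent
    with DaoServiceComponent {

  override val databaseService = new H2DatabaseService
  override val daoService = new DaoServiceImpl
}

// Callers reach the DAO layer through the registry:
// val connection = ComponentRegistry.daoService.getConnection()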
Example 64
Source File: DockerTmpDB.scala    From akka-stream-extensions   with Apache License 2.0 5 votes vote down vote up
package com.mfglabs.stream
package extensions.postgres

import java.sql.{DriverManager, Connection}
import org.postgresql.util.PSQLException
import org.scalatest.{Suite, BeforeAndAfter}
import scala.sys.process._
import scala.util.{Failure, Success, Try}
import com.typesafe.config.ConfigFactory

trait DockerTmpDB extends BeforeAndAfter { self: Suite =>

  import Debug._

  val version: PostgresVersion = PostgresVersion(ConfigFactory.load().getString("postgres.version"))

  Class.forName("org.postgresql.Driver")
  implicit var conn : Connection = _

  val dockerInstances = collection.mutable.Buffer.empty[String]

  def newPGDB(): Int = {
    val port: Int = 5432 + (math.random * (10000 - 5432)).toInt
    Try {
      s"docker pull postgres:${version.value}".pp.!!.trim
      val containerId =
        s"""docker run -p $port:5432 -e POSTGRES_PASSWORD=pwd -d postgres:${version.value}""".pp.!!.trim
      dockerInstances += containerId.pp("New docker instance with id")
      port
    } match {
      case Success(p) => p
      case Failure(err) =>
        throw new IllegalStateException(s"Error while trying to run docker container", err)
    }
  }

  lazy val dockerIp: String =
    Try("docker-machine ip default".!!.trim).toOption
      .orElse {
        val conf = ConfigFactory.load()
        if (conf.hasPath("docker.ip")) Some(conf.getString("docker.ip")) else None
      }
      .getOrElse("127.0.0.1") // platform dependent

  // ugly solution to wait for the connection to be ready
  def waitsForConnection(port: Int): Connection = {
    try {
      DriverManager.getConnection(s"jdbc:postgresql://$dockerIp:$port/postgres", "postgres", "pwd")
    } catch {
      case _: PSQLException =>
        println("Retrying DB connection...")
        Thread.sleep(1000)
        waitsForConnection(port)
    }
  }

  before {
    val port = newPGDB()
    println(s"New postgres ${version.value} instance at port $port")
    Thread.sleep(5000)
    conn = waitsForConnection(port)
  }

  after {
    conn.close()
    dockerInstances.toSeq.foreach { dockerId =>
      s"docker stop $dockerId".pp.!!
      s"docker rm $dockerId".pp.!!
    }
  }

}

object Debug {

  implicit class RichString(s:String){
    def pp :String = pp(None)
    def pp(p:String) :String = pp(Some(p))

    private def pp(p:Option[String]) = {
      println(p.map(_ + " ").getOrElse("") + s)
      s
    }
  }
} 
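A hedged usage sketch of the trait: a suite mixes in DockerTmpDB and queries the throwaway PostgreSQL instance through the conn variable populated in before (assuming Docker is available and the postgres.version config key is set):

import com.mfglabs.stream.extensions.postgres.DockerTmpDB
import org.scalatest.FunSuite

class PostgresSmokeSpec extends FunSuite with DockerTmpDB {
  test("the temporary PostgreSQL container accepts queries") {
    // conn is opened against the freshly started container before each test
    val rs = conn.createStatement().executeQuery("SELECT version()")
    assert(rs.next())
  }
}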
Example 65
Source File: Databases.scala    From eclair   with Apache License 2.0 5 votes vote down vote up
package fr.acinq.eclair.db

import java.io.File
import java.sql.{Connection, DriverManager}

import fr.acinq.eclair.db.sqlite._
import grizzled.slf4j.Logging
import org.sqlite.SQLiteException

trait Databases {

  val network: NetworkDb

  val audit: AuditDb

  val channels: ChannelsDb

  val peers: PeersDb

  val payments: PaymentsDb

  val pendingRelay: PendingRelayDb

  def backup(file: File): Unit
}

object Databases extends Logging {

  
  def sqliteJDBC(dbdir: File): Databases = {
    dbdir.mkdir()
    var sqliteEclair: Connection = null
    var sqliteNetwork: Connection = null
    var sqliteAudit: Connection = null
    try {
      sqliteEclair = DriverManager.getConnection(s"jdbc:sqlite:${new File(dbdir, "eclair.sqlite")}")
      sqliteNetwork = DriverManager.getConnection(s"jdbc:sqlite:${new File(dbdir, "network.sqlite")}")
      sqliteAudit = DriverManager.getConnection(s"jdbc:sqlite:${new File(dbdir, "audit.sqlite")}")
      SqliteUtils.obtainExclusiveLock(sqliteEclair) // there should only be one process writing to this file
      logger.info("successful lock on eclair.sqlite")
      databaseByConnections(sqliteAudit, sqliteNetwork, sqliteEclair)
    } catch {
      case t: Throwable => {
        logger.error("could not create connection to sqlite databases: ", t)
        if (sqliteEclair != null) sqliteEclair.close()
        if (sqliteNetwork != null) sqliteNetwork.close()
        if (sqliteAudit != null) sqliteAudit.close()
        throw t
      }
    }

  }

  def databaseByConnections(auditJdbc: Connection, networkJdbc: Connection, eclairJdbc: Connection) = new Databases {
    override val network = new SqliteNetworkDb(networkJdbc)
    override val audit = new SqliteAuditDb(auditJdbc)
    override val channels = new SqliteChannelsDb(eclairJdbc)
    override val peers = new SqlitePeersDb(eclairJdbc)
    override val payments = new SqlitePaymentsDb(eclairJdbc)
    override val pendingRelay = new SqlitePendingRelayDb(eclairJdbc)

    override def backup(file: File): Unit = {
      SqliteUtils.using(eclairJdbc.createStatement()) {
        statement => {
          statement.executeUpdate(s"backup to ${file.getAbsolutePath}")
        }
      }
    }
  }
} 
Example 66
Source File: SqlitePeersDb.scala    From eclair   with Apache License 2.0 5 votes vote down vote up
package fr.acinq.eclair.db.sqlite

import java.sql.Connection

import fr.acinq.bitcoin.Crypto
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.db.PeersDb
import fr.acinq.eclair.db.sqlite.SqliteUtils.{codecSequence, getVersion, using}
import fr.acinq.eclair.wire._
import scodec.bits.BitVector

class SqlitePeersDb(sqlite: Connection) extends PeersDb {

  import SqliteUtils.ExtendedResultSet._

  val DB_NAME = "peers"
  val CURRENT_VERSION = 1

  using(sqlite.createStatement(), inTransaction = true) { statement =>
    require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed
    statement.executeUpdate("CREATE TABLE IF NOT EXISTS peers (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
  }

  override def addOrUpdatePeer(nodeId: Crypto.PublicKey, nodeaddress: NodeAddress): Unit = {
    val data = CommonCodecs.nodeaddress.encode(nodeaddress).require.toByteArray
    using(sqlite.prepareStatement("UPDATE peers SET data=? WHERE node_id=?")) { update =>
      update.setBytes(1, data)
      update.setBytes(2, nodeId.value.toArray)
      if (update.executeUpdate() == 0) {
        using(sqlite.prepareStatement("INSERT INTO peers VALUES (?, ?)")) { statement =>
          statement.setBytes(1, nodeId.value.toArray)
          statement.setBytes(2, data)
          statement.executeUpdate()
        }
      }
    }
  }

  override def removePeer(nodeId: Crypto.PublicKey): Unit = {
    using(sqlite.prepareStatement("DELETE FROM peers WHERE node_id=?")) { statement =>
      statement.setBytes(1, nodeId.value.toArray)
      statement.executeUpdate()
    }
  }

  override def getPeer(nodeId: PublicKey): Option[NodeAddress] = {
    using(sqlite.prepareStatement("SELECT data FROM peers WHERE node_id=?")) { statement =>
      statement.setBytes(1, nodeId.value.toArray)
      val rs = statement.executeQuery()
      codecSequence(rs, CommonCodecs.nodeaddress).headOption
    }
  }

  override def listPeers(): Map[PublicKey, NodeAddress] = {
    using(sqlite.createStatement()) { statement =>
      val rs = statement.executeQuery("SELECT node_id, data FROM peers")
      var m: Map[PublicKey, NodeAddress] = Map()
      while (rs.next()) {
        val nodeid = PublicKey(rs.getByteVector("node_id"))
        val nodeaddress = CommonCodecs.nodeaddress.decode(BitVector(rs.getBytes("data"))).require.value
        m += (nodeid -> nodeaddress)
      }
      m
    }
  }

  // used by mobile apps
  override def close(): Unit = sqlite.close()
} 
Example 67
Source File: SqliteFeeratesDb.scala    From eclair   with Apache License 2.0 5 votes vote down vote up
package fr.acinq.eclair.db.sqlite

import java.sql.Connection

import fr.acinq.eclair.blockchain.fee.FeeratesPerKB
import fr.acinq.eclair.db.FeeratesDb


class SqliteFeeratesDb(sqlite: Connection) extends FeeratesDb {

  import SqliteUtils._

  val DB_NAME = "feerates"
  val CURRENT_VERSION = 1

  using(sqlite.createStatement(), inTransaction = true) { statement =>
    getVersion(statement, DB_NAME, CURRENT_VERSION) match {
      case CURRENT_VERSION =>
        // Create feerates table. Rates are in kb.
        statement.executeUpdate(
          """
            |CREATE TABLE IF NOT EXISTS feerates_per_kb (
            |rate_block_1 INTEGER NOT NULL, rate_blocks_2 INTEGER NOT NULL, rate_blocks_6 INTEGER NOT NULL, rate_blocks_12 INTEGER NOT NULL, rate_blocks_36 INTEGER NOT NULL, rate_blocks_72 INTEGER NOT NULL, rate_blocks_144 INTEGER NOT NULL,
            |timestamp INTEGER NOT NULL)""".stripMargin)
      case unknownVersion => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
    }
  }

  override def addOrUpdateFeerates(feeratesPerKB: FeeratesPerKB): Unit = {
    using(sqlite.prepareStatement("UPDATE feerates_per_kb SET rate_block_1=?, rate_blocks_2=?, rate_blocks_6=?, rate_blocks_12=?, rate_blocks_36=?, rate_blocks_72=?, rate_blocks_144=?, timestamp=?")) { update =>
      update.setLong(1, feeratesPerKB.block_1)
      update.setLong(2, feeratesPerKB.blocks_2)
      update.setLong(3, feeratesPerKB.blocks_6)
      update.setLong(4, feeratesPerKB.blocks_12)
      update.setLong(5, feeratesPerKB.blocks_36)
      update.setLong(6, feeratesPerKB.blocks_72)
      update.setLong(7, feeratesPerKB.blocks_144)
      update.setLong(8, System.currentTimeMillis())
      if (update.executeUpdate() == 0) {
        using(sqlite.prepareStatement("INSERT INTO feerates_per_kb VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { insert =>
          insert.setLong(1, feeratesPerKB.block_1)
          insert.setLong(2, feeratesPerKB.blocks_2)
          insert.setLong(3, feeratesPerKB.blocks_6)
          insert.setLong(4, feeratesPerKB.blocks_12)
          insert.setLong(5, feeratesPerKB.blocks_36)
          insert.setLong(6, feeratesPerKB.blocks_72)
          insert.setLong(7, feeratesPerKB.blocks_144)
          insert.setLong(8, System.currentTimeMillis())
          insert.executeUpdate()
        }
      }
    }
  }

  override def getFeerates(): Option[FeeratesPerKB] = {
    using(sqlite.prepareStatement("SELECT rate_block_1, rate_blocks_2, rate_blocks_6, rate_blocks_12, rate_blocks_36, rate_blocks_72, rate_blocks_144 FROM feerates_per_kb")) { statement =>
      val rs = statement.executeQuery()
      if (rs.next()) {
        Some(FeeratesPerKB(
          block_1 = rs.getLong("rate_block_1"),
          blocks_2 = rs.getLong("rate_blocks_2"),
          blocks_6 = rs.getLong("rate_blocks_6"),
          blocks_12 = rs.getLong("rate_blocks_12"),
          blocks_36 = rs.getLong("rate_blocks_36"),
          blocks_72 = rs.getLong("rate_blocks_72"),
          blocks_144 = rs.getLong("rate_blocks_144")))
      } else {
        None
      }
    }
  }

  // used by mobile apps
  override def close(): Unit = sqlite.close()
} 
Example 68
Source File: SqlitePendingRelayDb.scala    From eclair   with Apache License 2.0 5 votes vote down vote up
package fr.acinq.eclair.db.sqlite

import java.sql.Connection

import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.channel.{Command, HasHtlcId}
import fr.acinq.eclair.db.PendingRelayDb
import fr.acinq.eclair.wire.CommandCodecs.cmdCodec

import scala.collection.immutable.Queue

class SqlitePendingRelayDb(sqlite: Connection) extends PendingRelayDb {

  import SqliteUtils.ExtendedResultSet._
  import SqliteUtils._

  val DB_NAME = "pending_relay"
  val CURRENT_VERSION = 1

  using(sqlite.createStatement(), inTransaction = true) { statement =>
    require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed
    // note: should we use a foreign key to local_channels table here?
    statement.executeUpdate("CREATE TABLE IF NOT EXISTS pending_relay (channel_id BLOB NOT NULL, htlc_id INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(channel_id, htlc_id))")
  }

  override def addPendingRelay(channelId: ByteVector32, cmd: Command with HasHtlcId): Unit = {
    using(sqlite.prepareStatement("INSERT OR IGNORE INTO pending_relay VALUES (?, ?, ?)")) { statement =>
      statement.setBytes(1, channelId.toArray)
      statement.setLong(2, cmd.id)
      statement.setBytes(3, cmdCodec.encode(cmd).require.toByteArray)
      statement.executeUpdate()
    }
  }

  override def removePendingRelay(channelId: ByteVector32, htlcId: Long): Unit = {
    using(sqlite.prepareStatement("DELETE FROM pending_relay WHERE channel_id=? AND htlc_id=?")) { statement =>
      statement.setBytes(1, channelId.toArray)
      statement.setLong(2, htlcId)
      statement.executeUpdate()
    }
  }

  override def listPendingRelay(channelId: ByteVector32): Seq[Command with HasHtlcId] = {
    using(sqlite.prepareStatement("SELECT data FROM pending_relay WHERE channel_id=?")) { statement =>
      statement.setBytes(1, channelId.toArray)
      val rs = statement.executeQuery()
      codecSequence(rs, cmdCodec)
    }
  }

  override def listPendingRelay(): Set[(ByteVector32, Long)] = {
    using(sqlite.prepareStatement("SELECT channel_id, htlc_id FROM pending_relay")) { statement =>
      val rs = statement.executeQuery()
      var q: Queue[(ByteVector32, Long)] = Queue()
      while (rs.next()) {
        q = q :+ (rs.getByteVector32("channel_id"), rs.getLong("htlc_id"))
      }
      q.toSet
    }
  }

  // used by mobile apps
  override def close(): Unit = sqlite.close()
} 
Example 69
Source File: DoobieConnectionIOEffect.scala    From eff   with MIT License 5 votes vote down vote up
package org.atnos.eff.addon.doobie

import java.sql.Connection

import _root_.doobie.Transactor
import _root_.doobie.free.connection.ConnectionIO
import cats.effect.Bracket
import cats.implicits._
import cats.~>
import org.atnos.eff._
import org.atnos.eff.all._

trait DoobieConnectionIOTypes {
  type _connectionIO[R] = ConnectionIO |= R
  type _ConnectionIO[R] = ConnectionIO <= R
}

trait DoobieConnectionIOCreation extends DoobieConnectionIOTypes {
  final def fromConnectionIO[R: _connectionIO, A](a: ConnectionIO[A]): Eff[R, A] =
    send[ConnectionIO, R, A](a)
}

trait DoobieConnectionIOInterpretation extends DoobieConnectionIOTypes {

  def runConnectionIO[R, U, F[_], E, A, B](e: Eff[R, A])(t: Transactor[F])(
    implicit mc: Member.Aux[ConnectionIO, R, U],
             mf: F /= U,
             me: Bracket[F, Throwable] ): Eff[U, A] = {

    def getConnection: Eff[U, Connection] =
      send[F, U, Connection](t.connect(t.kernel).allocated.map(_._1))

    def runEffect(connection: Connection): Eff[U, A] =
      interpret.translate(e)(new Translate[ConnectionIO, U] {
        def apply[X](c: ConnectionIO[X]): Eff[U, X] = {
          send[F, U, X](c.foldMap(t.interpret).run(connection))
        }
      })

    def interceptErrors[Y](effect: Eff[U, Y])(oops: F[Unit]): Eff[U, Y] =
      interpret.interceptNat(effect)(new (F ~> F) {
        def apply[X](f: F[X]): F[X] =
          f.handleErrorWith((err: Throwable) => oops *> me.raiseError[X](err))
      })

    getConnection.flatMap { connection =>
      lazy val always: F[Unit] =
        t.strategy.always.foldMap(t.interpret).run(connection)

      lazy val oops: F[Unit] =
        t.strategy.oops.foldMap(t.interpret).run(connection)

      val before: Eff[U, Unit] =
        send(t.strategy.before.foldMap(t.interpret).run(connection))

      val after: Eff[U, Unit] =
        send(t.strategy.after.foldMap(t.interpret).run(connection))

      interceptErrors(before >> runEffect(connection) << after)(oops).addLast(send(always))
    }
  }
}

object DoobieConnectionIOCreation extends DoobieConnectionIOCreation

object DoobieConnectionIOInterpretation extends DoobieConnectionIOInterpretation

trait DoobieConnectionIOEffect extends DoobieConnectionIOCreation with DoobieConnectionIOInterpretation

object DoobieConnectionIOEffect extends DoobieConnectionIOEffect 
Example 70
Source File: JdbcUtil.scala    From bahir   with Apache License 2.0 5 votes vote down vote up
package org.apache.bahir.sql.streaming.jdbc

import java.sql.{Connection, PreparedStatement}
import java.util.Locale

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


object JdbcUtil {

  def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
    dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).getOrElse(
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
  }

  // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
  // `PreparedStatement`. The last argument `Int` means the index for the value to be set
  // in the SQL statement and also used for the value in `Row`.
  type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))

    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val strValue = row.get(pos) match {
          case str: UTF8String => str.toString
          case str: String => str
        }
        stmt.setString(pos + 1, strValue)

    case BinaryType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case ArrayType(et, _) =>
      // remove type length parameters from end of type name
      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
        .toLowerCase(Locale.ROOT).split("\\(")(0)
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val array = conn.createArrayOf(
          typeName,
          row.getSeq[AnyRef](pos).toArray)
        stmt.setArray(pos + 1, array)

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
} 
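makeSetter produces one JDBCValueSetter per column, so binding a whole Row is a loop over the schema. A hedged sketch (bindRow is a hypothetical helper, and the PreparedStatement is assumed to have one placeholder per field in schema order):

import java.sql.{Connection, PreparedStatement}
import org.apache.bahir.sql.streaming.jdbc.JdbcUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types.StructType

object JdbcUtilUsageSketch {
  // Bind every field of `row` into `stmt`, using setNull for null columns and
  // the dialect-aware setters from JdbcUtil.makeSetter for everything else.
  def bindRow(conn: Connection, dialect: JdbcDialect, schema: StructType,
              row: Row, stmt: PreparedStatement): Unit =
    schema.fields.zipWithIndex.foreach { case (field, pos) =>
      if (row.isNullAt(pos))
        stmt.setNull(pos + 1, JdbcUtil.getJdbcType(field.dataType, dialect).jdbcNullType)
      else
        JdbcUtil.makeSetter(conn, dialect, field.dataType)(stmt, row, pos)
    }
}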
Example 71
Source File: NetezzaRDD.scala    From spark-netezza   with Apache License 2.0 5 votes vote down vote up
package com.ibm.spark.netezza

import java.sql.Connection
import java.util.Properties

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.{Partition, SparkContext, TaskContext}


  override def compute(thePart: Partition, context: TaskContext): Iterator[Row] =
    new Iterator[Row] {
      var closed = false
      var finished = false
      var gotNext = false
      var nextValue: Row = null

      context.addTaskCompletionListener { context => close() }
      val part = thePart.asInstanceOf[NetezzaPartition]
      val conn = getConnection()
      val reader = new NetezzaDataReader(conn, table, columns, filters, part, schema)
      reader.startExternalTableDataUnload()

      def getNext(): Row = {
        if (reader.hasNext) {
          reader.next()
        } else {
          finished = true
          null.asInstanceOf[Row]
        }
      }

      def close() {
        if (closed) return
        try {
          if (null != reader) {
            reader.close()
          }
        } catch {
          case e: Exception => logWarning("Exception closing Netezza record reader", e)
        }
        try {
          if (null != conn) {
            conn.close()
          }
          logInfo("closed connection")
        } catch {
          case e: Exception => logWarning("Exception closing connection", e)
        }
      }

      override def hasNext: Boolean = {
        if (!finished) {
          if (!gotNext) {
            nextValue = getNext()
            if (finished) {
              close()
            }
            gotNext = true
          }
        }
        !finished
      }

      override def next(): Row = {
        if (!hasNext) {
          throw new NoSuchElementException("End of stream")
        }
        gotNext = false
        nextValue
      }
    }
} 
Example 72
Source File: IntegrationSuiteBase.scala    From spark-netezza   with Apache License 2.0 5 votes vote down vote up
package com.ibm.spark.netezza.integration

import java.sql.Connection

import com.ibm.spark.netezza.NetezzaJdbcUtils
import com.typesafe.config.ConfigFactory
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, DataFrame, SQLContext}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory

trait IntegrationSuiteBase extends FunSuite with BeforeAndAfterAll with QueryTest{
  private val log = LoggerFactory.getLogger(getClass)

  protected var sc: SparkContext = _
  protected var sqlContext: SQLContext = _
  protected var conn: Connection = _
  protected val prop = new java.util.Properties

  // Configurable vals
  protected var configFile = "application"
  protected var testURL: String = _
  protected var testTable: String = _
  protected var user: String = _
  protected var password: String = _
  protected var numPartitions: Int = _
  protected var sampleDbmaxNumTables: Int = _

  override def beforeAll(): Unit = {
    super.beforeAll()

    sc = new SparkContext("local[*]", "IntegrationTest", new SparkConf())
    sqlContext = new SQLContext(sc)

    val conf = ConfigFactory.load(configFile)
    testURL = conf.getString("test.integration.dbURL")
    testTable = conf.getString("test.integration.table")
    user = conf.getString("test.integration.user")
    password = conf.getString("test.integration.password")
    numPartitions = conf.getInt("test.integration.partition.number")
    sampleDbmaxNumTables = conf.getInt("test.integration.max.numtables")
    prop.setProperty("user", user)
    prop.setProperty("password", password)
    log.info("Attempting to get connection from" + testURL)
    conn = NetezzaJdbcUtils.getConnector(testURL, prop)()
    log.info("got connection.")
  }

  override def afterAll(): Unit = {
    try {
      sc.stop()
    }
    finally {
      conn.close()
      super.afterAll()
    }
  }

  
  def withTable(tableNames: String*)(f: => Unit): Unit = {
    try f finally {
      tableNames.foreach { name =>
        executeJdbcStmt(s"DROP TABLE $name")
      }
    }
  }
} 
Example 73
Source File: JdbcConnector.scala    From bandar-log   with Apache License 2.0 5 votes vote down vote up
package com.aol.one.dwh.bandarlog.connectors

import java.sql.{Connection, Statement}

import com.aol.one.dwh.infra.sql.Setting
import com.aol.one.dwh.infra.sql.pool.HikariConnectionPool
import com.aol.one.dwh.infra.util.LogTrait
import com.aol.one.dwh.infra.sql.Query
import com.aol.one.dwh.infra.sql.pool.SqlSource.{PRESTO, VERTICA}
import com.facebook.presto.jdbc.PrestoConnection
import com.google.common.cache.CacheBuilder
import com.vertica.jdbc.VerticaConnection
import org.apache.commons.dbutils.ResultSetHandler
import resource.managed

import scala.concurrent.duration._
import scala.util.Try
import scalacache.guava.GuavaCache
import scalacache.memoization._
import scalacache.{CacheConfig, ScalaCache}


abstract class JdbcConnector(@cacheKeyExclude pool: HikariConnectionPool) extends LogTrait {

  implicit val scalaCache = ScalaCache(
    GuavaCache(CacheBuilder.newBuilder().maximumSize(100).build[String, Object]),
    cacheConfig = CacheConfig(keyPrefix = Some(pool.getName))
  )

  def runQuery[V](query: Query, @cacheKeyExclude handler: ResultSetHandler[V]): V = memoizeSync(50.seconds) {
    val rm =
      for {
        connection <- managed(pool.getConnection)
        statement  <- managed(connection.createStatement())
      } yield {
        applySettings(connection, statement, query.settings)
        logger.info(s"Running query:[${query.sql}] source:[${query.source}] settings:[${query.settings.mkString(",")}]")
        val resultSet = statement.executeQuery(query.sql)
        handler.handle(resultSet)
      }

    Try(rm.acquireAndGet(identity)).getOrElse(throw new RuntimeException(s"Failure:[$query]"))
  }

  private def applySettings(connection: Connection, statement: Statement, settings: Seq[Setting]) = {
    settings.foreach(setting => applySetting(connection, statement, setting))
  }

  def applySetting(connection: Connection, statement: Statement, setting: Setting)

}

object JdbcConnector {

  private class PrestoConnector(connectionPool: HikariConnectionPool) extends JdbcConnector(connectionPool) {
    override def applySetting(connection: Connection, statement: Statement, setting: Setting): Unit = {
      connection.unwrap(classOf[PrestoConnection]).setSessionProperty(setting.key, setting.value)
    }
  }

  private class VerticaConnector(connectionPool: HikariConnectionPool) extends JdbcConnector(connectionPool) {
    override def applySetting(connection: Connection, statement: Statement, setting: Setting): Unit = {
      connection.unwrap(classOf[VerticaConnection]).setProperty(setting.key, setting.value)
    }
  }

  def apply(connectorType: String, connectionPool: HikariConnectionPool): JdbcConnector = connectorType match {
    case VERTICA => new VerticaConnector(connectionPool)
    case PRESTO => new PrestoConnector(connectionPool)
    case _ => throw new IllegalArgumentException(s"Can't create connector for SQL source:[$connectorType]")
  }
} 
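A hedged usage sketch of the factory and runQuery, matching the types exercised in the test below; the pool and handler are assumed to be provided by the surrounding application, and my_table/batch_id are illustrative names:

import com.aol.one.dwh.bandarlog.connectors.JdbcConnector
import com.aol.one.dwh.infra.config.Table
import com.aol.one.dwh.infra.sql.VerticaMaxValuesQuery
import com.aol.one.dwh.infra.sql.pool.HikariConnectionPool
import com.aol.one.dwh.infra.sql.pool.SqlSource.VERTICA
import org.apache.commons.dbutils.ResultSetHandler

object JdbcConnectorUsageSketch {
  // Pick the Vertica connector and ask it for the maximum batch_id of an illustrative table.
  def latestBatchId(pool: HikariConnectionPool, handler: ResultSetHandler[Long]): Long = {
    val connector = JdbcConnector(VERTICA, pool)
    connector.runQuery(VerticaMaxValuesQuery(Table("my_table", List("batch_id"), None)), handler)
  }
}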
Example 74
Source File: JdbcConnectorTest.scala    From bandar-log   with Apache License 2.0 5 votes vote down vote up
package com.aol.one.dwh.bandarlog.connectors

import java.sql.{Connection, DatabaseMetaData, ResultSet, Statement}

import com.aol.one.dwh.infra.config._
import com.aol.one.dwh.infra.sql.pool.HikariConnectionPool
import com.aol.one.dwh.infra.sql.{ListStringResultHandler, Setting, VerticaMaxValuesQuery}
import org.apache.commons.dbutils.ResultSetHandler
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

class JdbcConnectorTest extends FunSuite with MockitoSugar {

  private val statement = mock[Statement]
  private val resultSet = mock[ResultSet]
  private val connectionPool = mock[HikariConnectionPool]
  private val connection = mock[Connection]
  private val databaseMetaData = mock[DatabaseMetaData]
  private val resultSetHandler = mock[ResultSetHandler[Long]]
  private val listStringResultHandler = mock[ListStringResultHandler]

  test("check run query result for numeric batch_id column") {
    val resultValue = 100L
    val table = Table("table", List("column"), None)
    val query = VerticaMaxValuesQuery(table)
    when(connectionPool.getConnection).thenReturn(connection)
    when(connectionPool.getName).thenReturn("connection_pool_name")
    when(connection.createStatement()).thenReturn(statement)
    when(statement.executeQuery("SELECT MAX(column) AS column FROM table")).thenReturn(resultSet)
    when(connection.getMetaData).thenReturn(databaseMetaData)
    when(databaseMetaData.getURL).thenReturn("connection_url")
    when(resultSetHandler.handle(resultSet)).thenReturn(resultValue)

    val result = new DefaultJdbcConnector(connectionPool).runQuery(query, resultSetHandler)

    assert(result == resultValue)
  }

  test("check run query result for date/time partitions") {
    val resultValue = Some(20190924L)
    val table = Table("table", List("year", "month", "day"), Some(List("yyyy", "MM", "dd")))
    val query = VerticaMaxValuesQuery(table)
    when(connectionPool.getConnection).thenReturn(connection)
    when(connectionPool.getName).thenReturn("connection_pool_name")
    when(connection.createStatement()).thenReturn(statement)
    when(statement.executeQuery("SELECT DISTINCT year, month, day FROM table")).thenReturn(resultSet)
    when(connection.getMetaData).thenReturn(databaseMetaData)
    when(databaseMetaData.getURL).thenReturn("connection_url")
    when(listStringResultHandler.handle(resultSet)).thenReturn(resultValue)

    val result = new DefaultJdbcConnector(connectionPool).runQuery(query, listStringResultHandler)

    assert(result == resultValue)
  }
}

class DefaultJdbcConnector(connectionPool: HikariConnectionPool) extends JdbcConnector(connectionPool) {
  override def applySetting(connection: Connection, statement: Statement, setting: Setting): Unit = {}
} 
Example 75
Source File: TimeBasedDataService.scala    From kafka-jdbc-connector   with Apache License 2.0 5 votes vote down vote up
package com.agoda.kafka.connector.jdbc.services

import java.sql.{Connection, PreparedStatement, ResultSet, Timestamp}
import java.util.{Date, GregorianCalendar, TimeZone}

import com.agoda.kafka.connector.jdbc.JdbcSourceConnectorConstants
import com.agoda.kafka.connector.jdbc.models.DatabaseProduct
import com.agoda.kafka.connector.jdbc.models.DatabaseProduct.{MsSQL, MySQL}
import com.agoda.kafka.connector.jdbc.models.Mode.TimestampMode
import com.agoda.kafka.connector.jdbc.utils.DataConverter
import org.apache.kafka.connect.data.Schema
import org.apache.kafka.connect.source.SourceRecord

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import scala.util.Try


case class TimeBasedDataService(databaseProduct: DatabaseProduct,
                                storedProcedureName: String,
                                batchSize: Int,
                                batchSizeVariableName: String,
                                timestampVariableName: String,
                                var timestampOffset: Long,
                                timestampFieldName: String,
                                topic: String,
                                keyFieldOpt: Option[String],
                                dataConverter: DataConverter,
                                calendar: GregorianCalendar = new GregorianCalendar(TimeZone.getTimeZone("UTC"))
                               ) extends DataService {

  override def createPreparedStatement(connection: Connection): Try[PreparedStatement] = Try {
    val preparedStatement = databaseProduct match {
      case MsSQL => connection.prepareStatement(s"EXECUTE $storedProcedureName @$timestampVariableName = ?, @$batchSizeVariableName = ?")
      case MySQL => connection.prepareStatement(s"CALL $storedProcedureName (@$timestampVariableName := ?, @$batchSizeVariableName := ?)")
    }
    preparedStatement.setTimestamp(1, new Timestamp(timestampOffset), calendar)
    preparedStatement.setObject(2, batchSize)
    preparedStatement
  }

  override def extractRecords(resultSet: ResultSet, schema: Schema): Try[Seq[SourceRecord]] = Try {
    val sourceRecords = ListBuffer.empty[SourceRecord]
    var max = timestampOffset
    while (resultSet.next()) {
      dataConverter.convertRecord(schema, resultSet) map { record =>
        val time = record.get(timestampFieldName).asInstanceOf[Date].getTime
        max = if(time > max) {
          keyFieldOpt match {
            case Some(keyField) =>
              sourceRecords += new SourceRecord(
                Map(JdbcSourceConnectorConstants.STORED_PROCEDURE_NAME_KEY -> storedProcedureName).asJava,
                Map(TimestampMode.entryName -> time).asJava, topic, null, schema, record.get(keyField), schema, record
              )
            case None           =>
              sourceRecords += new SourceRecord(
                Map(JdbcSourceConnectorConstants.STORED_PROCEDURE_NAME_KEY -> storedProcedureName).asJava,
                Map(TimestampMode.entryName -> time).asJava, topic, schema, record
              )
          }
          time
        } else max
      }
    }
    timestampOffset = max
    sourceRecords
  }

  override def toString: String = {
    s"""
       |{
       |   "name" : "${this.getClass.getSimpleName}"
       |   "mode" : "${TimestampMode.entryName}"
       |   "stored-procedure.name" : "$storedProcedureName"
       |}
    """.stripMargin
  }
} 
Example 76
Source File: DataService.scala    From kafka-jdbc-connector   with Apache License 2.0 5 votes vote down vote up
package com.agoda.kafka.connector.jdbc.services

import java.sql.{Connection, PreparedStatement, ResultSet}

import com.agoda.kafka.connector.jdbc.utils.DataConverter
import org.apache.kafka.connect.data.Schema
import org.apache.kafka.connect.source.SourceRecord

import scala.concurrent.duration.Duration
import scala.util.Try

trait DataService {

  def storedProcedureName: String

  def dataConverter: DataConverter

  def getRecords(connection: Connection, timeout: Duration): Try[Seq[SourceRecord]] = {
    for {
      preparedStatement <- createPreparedStatement(connection)
      resultSet         <- executeStoredProcedure(preparedStatement, timeout)
      schema            <- dataConverter.convertSchema(storedProcedureName, resultSet.getMetaData)
      records           <- extractRecords(resultSet, schema)
    } yield records
  }

  protected def createPreparedStatement(connection: Connection): Try[PreparedStatement]

  protected def extractRecords(resultSet: ResultSet, schema: Schema): Try[Seq[SourceRecord]]

  private def executeStoredProcedure(preparedStatement: PreparedStatement, timeout: Duration): Try[ResultSet] = Try {
    preparedStatement.setQueryTimeout(timeout.toSeconds.toInt)
    preparedStatement.executeQuery
  }
} 
Example 77
Source File: DataServiceTest.scala    From kafka-jdbc-connector   with Apache License 2.0 5 votes vote down vote up
package com.agoda.kafka.connector.jdbc.services

import java.sql.{Connection, PreparedStatement, ResultSet, ResultSetMetaData}

import com.agoda.kafka.connector.jdbc.utils.DataConverter
import org.apache.kafka.connect.data.Schema
import org.apache.kafka.connect.source.SourceRecord
import org.scalatest.mockito.MockitoSugar
import org.mockito.Mockito._
import org.scalatest.{Matchers, WordSpec}

import scala.concurrent.duration._
import scala.util.Success

class DataServiceTest extends WordSpec with Matchers with MockitoSugar {

  "Data Service" should {

    val spName = "stored-procedure"
    val connection = mock[Connection]
    val converter = mock[DataConverter]
    val sourceRecord1 = mock[SourceRecord]
    val sourceRecord2 = mock[SourceRecord]
    val resultSet = mock[ResultSet]
    val resultSetMetadata = mock[ResultSetMetaData]
    val preparedStatement = mock[PreparedStatement]
    val schema = mock[Schema]

    val dataService = new DataService {

      override def storedProcedureName: String = spName

      override protected def createPreparedStatement(connection: Connection) = Success(preparedStatement)

      override protected def extractRecords(resultSet: ResultSet, schema: Schema) = Success(Seq(sourceRecord1, sourceRecord2))

      override def dataConverter: DataConverter = converter
    }

    "get records" in {
      doNothing().when(preparedStatement).setQueryTimeout(1)
      when(preparedStatement.executeQuery).thenReturn(resultSet)
      when(resultSet.getMetaData).thenReturn(resultSetMetadata)
      when(converter.convertSchema(spName, resultSetMetadata)).thenReturn(Success(schema))

      dataService.getRecords(connection, 1.second) shouldBe Success(Seq(sourceRecord1, sourceRecord2))

      verify(preparedStatement).setQueryTimeout(1)
      verify(preparedStatement).executeQuery
      verify(resultSet).getMetaData
      verify(converter).convertSchema(spName, resultSetMetadata)
    }
  }
} 
Example 78
Source File: DbUtils.scala    From osmesa   with Apache License 2.0 5 votes vote down vote up
package osmesa.apps

import java.net.URI
import java.sql.Connection

import vectorpipe.util.DBUtils

object DbUtils {
  
  def saveLocations(procName: String, sequence: Int, databaseURI: URI) = {
    var connection: Connection = null
    try {
      connection = DBUtils.getJdbcConnection(databaseURI)
      val upsertSequence =
        connection.prepareStatement(
          """
            |INSERT INTO checkpoints (proc_name, sequence)
            |VALUES (?, ?)
            |ON CONFLICT (proc_name)
            |DO UPDATE SET sequence = ?
          """.stripMargin
        )
      upsertSequence.setString(1, procName)
      upsertSequence.setInt(2, sequence)
      upsertSequence.setInt(3, sequence)
      upsertSequence.execute()
    } finally {
      if (connection != null) connection.close()
    }
  }
} 
Example 79
Source File: PrestoContainer.scala    From testcontainers-scala   with MIT License 5 votes vote down vote up
package com.dimafeng.testcontainers

import java.sql.Connection

import org.testcontainers.containers.{PrestoContainer => JavaPrestoContainer}

case class PrestoContainer(
  dockerImageName: String = PrestoContainer.defaultDockerImageName,
  dbUsername: String = PrestoContainer.defaultDbUsername,
  dbName: String = PrestoContainer.defaultDbName,
  commonJdbcParams: JdbcDatabaseContainer.CommonParams = JdbcDatabaseContainer.CommonParams()
) extends SingleContainer[JavaPrestoContainer[_]] with JdbcDatabaseContainer {

  override val container: JavaPrestoContainer[_] = {
    val c = new JavaPrestoContainer(dockerImageName)
    c.withUsername(dbUsername)
    c.withDatabaseName(dbName)
    commonJdbcParams.applyTo(c)
    c
  }

  def testQueryString: String = container.getTestQueryString

  def createConnection: Connection = container.createConnection()
}

object PrestoContainer {

  val defaultDockerImageName = s"${JavaPrestoContainer.IMAGE}:${JavaPrestoContainer.DEFAULT_TAG}"
  val defaultDbUsername = "test"
  val defaultDbName = ""

  case class Def(
    dockerImageName: String = PrestoContainer.defaultDockerImageName,
    dbUsername: String = PrestoContainer.defaultDbUsername,
    dbName: String = PrestoContainer.defaultDbName,
    commonJdbcParams: JdbcDatabaseContainer.CommonParams = JdbcDatabaseContainer.CommonParams()
  ) extends ContainerDef {

    override type Container = PrestoContainer

    override def createContainer(): PrestoContainer = {
      new PrestoContainer(
        dockerImageName,
        dbUsername,
        dbName,
        commonJdbcParams
      )
    }
  }
} 
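A hedged usage sketch with the library's ScalaTest integration; ForAllTestContainer starts the container once per suite (exact ScalaTest style classes depend on the versions in use):

import com.dimafeng.testcontainers.{ForAllTestContainer, PrestoContainer}
import org.scalatest.funsuite.AnyFunSuite

class PrestoContainerSpec extends AnyFunSuite with ForAllTestContainer {
  override val container: PrestoContainer = PrestoContainer()

  test("Presto answers its own test query") {
    val connection = container.createConnection
    try {
      val rs = connection.createStatement().executeQuery(container.testQueryString)
      assert(rs.next())
    } finally connection.close()
  }
}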
Example 80
Source File: EmbeddedPostgreSQL.scala    From akka-http-microservice-templates   with MIT License 5 votes vote down vote up
package utils

import java.nio.file.Paths
import java.util

import ru.yandex.qatools.embed.postgresql.distribution.Version.V9_6_8

object EmbeddedPostgreSQL {
  import ru.yandex.qatools.embed.postgresql.EmbeddedPostgres

  val postgres = new EmbeddedPostgres(V9_6_8)

  def start = {
    val url: String = postgres.start(EmbeddedPostgres.cachedRuntimeConfig(Paths.get("/tmp/postgres")), "localhost", 5432, "users", "user", "password", util.Arrays.asList())

    import java.sql.{Connection, DriverManager}

    Class.forName("org.postgresql.Driver")

    val conn: Connection = DriverManager.getConnection(url)
    
    conn.createStatement().execute(
      """
        CREATE SEQUENCE public.users_id_seq
             INCREMENT 1
             START 1
             MINVALUE 1
             MAXVALUE 9223372036854775807
             CACHE 1;
      """)

    conn.createStatement().execute("""ALTER SEQUENCE public.users_id_seq OWNER TO "user";""")

    conn.createStatement().execute(
      """
      CREATE TABLE public.users (id integer NOT NULL DEFAULT nextval('users_id_seq'::regclass),
                      username character varying(255) COLLATE pg_catalog."default" NOT NULL,
                      age integer NOT NULL,
                      CONSTRAINT users_pkey PRIMARY KEY (id))
                     WITH (
                         OIDS = FALSE
                     )
                     TABLESPACE pg_default;
                     
      ALTER TABLE public.users OWNER to "user";
    """)
  }

  def stop =
    postgres.stop()
} 
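A sketch of how a test suite might drive the object above. The JDBC URL is inferred from the parameters passed to postgres.start and should be treated as an assumption.

import java.sql.DriverManager

import utils.EmbeddedPostgreSQL

object EmbeddedPostgreSQLExample extends App {
  EmbeddedPostgreSQL.start

  // Connect with the same host, port, database and credentials used in start above.
  val connection = DriverManager.getConnection(
    "jdbc:postgresql://localhost:5432/users?user=user&password=password")
  val resultSet = connection.createStatement().executeQuery("SELECT count(*) FROM public.users")
  while (resultSet.next()) println(s"users in table: ${resultSet.getInt(1)}")
  connection.close()

  EmbeddedPostgreSQL.stop
}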
Example 81
Source File: EmbeddedPostgreSQL.scala    From akka-http-microservice-templates   with MIT License 5 votes vote down vote up
package utils

import java.nio.file.Paths
import java.util

import ru.yandex.qatools.embed.postgresql.distribution.Version.V9_6_8

object EmbeddedPostgreSQL {
  import ru.yandex.qatools.embed.postgresql.EmbeddedPostgres

  val postgres = new EmbeddedPostgres(V9_6_8)

  def start = {
    val url: String = postgres.start(EmbeddedPostgres.cachedRuntimeConfig(Paths.get("/tmp/postgres")), "localhost", 5432, "users", "user", "password", util.Arrays.asList())

    import java.sql.{Connection, DriverManager}

    Class.forName("org.postgresql.Driver")

    val conn: Connection = DriverManager.getConnection(url)
    
    conn.createStatement().execute(
      """
        CREATE SEQUENCE public.users_id_seq
             INCREMENT 1
             START 1
             MINVALUE 1
             MAXVALUE 9223372036854775807
             CACHE 1;
      """)

    conn.createStatement().execute("""ALTER SEQUENCE public.users_id_seq OWNER TO "user";""")

    conn.createStatement().execute(
      """
      CREATE TABLE public.users (id integer NOT NULL DEFAULT nextval('users_id_seq'::regclass),
                      username character varying(255) COLLATE pg_catalog."default" NOT NULL,
                      user_age integer NOT NULL,
                      CONSTRAINT users_pkey PRIMARY KEY (id))
                     WITH (
                         OIDS = FALSE
                     )
                     TABLESPACE pg_default;
                     
      ALTER TABLE public.users OWNER to "user";
    """)
  }

  def stop =
    postgres.stop()
} 
Example 82
Source File: H2Utils.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up
package org.opencypher.morpheus.testing.utils

import java.sql.{Connection, DriverManager, ResultSet, Statement}

import org.apache.spark.sql._
import org.opencypher.morpheus.api.io.sql.SqlDataSourceConfig

object H2Utils {

  implicit class ConnOps(conn: Connection) {
    def run[T](code: Statement => T): T = {
      val stmt = conn.createStatement()
      try { code(stmt) } finally { stmt.close() }
    }
    def execute(sql: String): Boolean = conn.run(_.execute(sql))
    def query(sql: String): ResultSet = conn.run(_.executeQuery(sql))
    def update(sql: String): Int = conn.run(_.executeUpdate(sql))
  }

  def withConnection[T](cfg: SqlDataSourceConfig.Jdbc)(code: Connection => T): T = {
    Class.forName(cfg.driver)
    val conn = (cfg.options.get("user"), cfg.options.get("password")) match {
      case (Some(user), Some(pass)) =>
        DriverManager.getConnection(cfg.url, user, pass)
      case _ =>
        DriverManager.getConnection(cfg.url)
    }
    try { code(conn) } finally { conn.close() }
  }

  implicit class DataFrameWriterOps(write: DataFrameWriter[Row]) {
    def maybeOption(key: String, value: Option[String]): DataFrameWriter[Row] =
      value.fold(write)(write.option(key, _))
  }

  implicit class DataFrameSqlOps(df: DataFrame) {

    def saveAsSqlTable(cfg: SqlDataSourceConfig.Jdbc, tableName: String): Unit =
      df.write
        .mode(SaveMode.Overwrite)
        .format("jdbc")
        .option("url", cfg.url)
        .option("driver", cfg.driver)
        .options(cfg.options)
        .option("dbtable", tableName)
        .save()
  }
} 
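A small sketch that exercises only the ConnOps implicit above, on a plain in-memory H2 connection (table name and data are illustrative). Note that results are consumed inside run, before the statement, and with it the ResultSet, is closed.

import java.sql.DriverManager

import org.opencypher.morpheus.testing.utils.H2Utils._

object ConnOpsExample extends App {
  // In-memory H2 database; DB_CLOSE_DELAY keeps it alive for the lifetime of the JVM.
  val conn = DriverManager.getConnection("jdbc:h2:mem:example;DB_CLOSE_DELAY=-1")
  try {
    conn.execute("CREATE TABLE person (name VARCHAR(64), age INT)")
    conn.update("INSERT INTO person VALUES ('Alice', 42)")
    val people = conn.run { stmt =>
      val rs = stmt.executeQuery("SELECT name, age FROM person")
      val buffer = scala.collection.mutable.ListBuffer.empty[String]
      while (rs.next()) buffer += s"${rs.getString("name")} is ${rs.getInt("age")}"
      buffer.toList
    }
    people.foreach(println)
  } finally conn.close()
}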
Example 83
Source File: PostgresDialect.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Types}

import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._


private object PostgresDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.REAL) {
      Some(FloatType)
    } else if (sqlType == Types.SMALLINT) {
      Some(ShortType)
    } else if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
      Some(BinaryType)
    } else if (sqlType == Types.OTHER) {
      Some(StringType)
    } else if (sqlType == Types.ARRAY) {
      val scale = md.build.getLong("scale").toInt
      // postgres array type names start with underscore
      toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_))
    } else None
  }

  private def toCatalystType(
      typeName: String,
      precision: Int,
      scale: Int): Option[DataType] = typeName match {
    case "bool" => Some(BooleanType)
    case "bit" => Some(BinaryType)
    case "int2" => Some(ShortType)
    case "int4" => Some(IntegerType)
    case "int8" | "oid" => Some(LongType)
    case "float4" => Some(FloatType)
    case "money" | "float8" => Some(DoubleType)
    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
      Some(StringType)
    case "bytea" => Some(BinaryType)
    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
    case "date" => Some(DateType)
    case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale))
    case _ => None
  }

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.CHAR))
    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
    case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
    case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case t: DecimalType => Some(
      JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
      getJDBCType(et).map(_.databaseTypeDefinition)
        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
    case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt");
    case _ => None
  }

  override def getTableExistsQuery(table: String): String = {
    s"SELECT 1 FROM $table LIMIT 1"
  }

  override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
    super.beforeFetch(connection, properties)

    // According to the postgres jdbc documentation we need to be in autocommit=false if we
    // actually want fetchsize to take effect (a fetch size of 0 means "fetch all the rows").
    // This keeps the driver from caching the whole result set in memory while fetching.
    //
    // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
    //
    if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
      connection.setAutoCommit(false)
    }

  }

  override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
} 
Example 84
Source File: KsqlDriver.scala    From ksql-jdbc-driver   with Apache License 2.0 5 votes vote down vote up
package com.github.mmolimar.ksql.jdbc

import java.sql.{Connection, Driver, DriverPropertyInfo}
import java.util.Properties
import java.util.logging.Logger

import com.github.mmolimar.ksql.jdbc.Exceptions._

import scala.util.matching.Regex

object KsqlDriver {

  val ksqlName = "ksqlDB"
  val ksqlPrefix = "jdbc:ksql://"

  val driverName = "ksqlDB JDBC driver"
  val driverMajorVersion = 1
  val driverMinorVersion = 2
  val driverVersion = s"$driverMajorVersion.$driverMinorVersion"

  val jdbcMajorVersion = 4
  val jdbcMinorVersion = 1

  val ksqlMajorVersion = 5
  val ksqlMinorVersion = 4
  val ksqlMicroVersion = 0
  val ksqlVersion = s"$ksqlMajorVersion.$ksqlMinorVersion.$ksqlMicroVersion"

  private val ksqlUserPassRegex = "((.+):(.+)@){0,1}"
  private val ksqlServerRegex = "([A-Za-z0-9._%+-]+):([0-9]{1,5})"
  private val ksqlPropsRegex = "(\\?([A-Za-z0-9._-]+=[A-Za-z0-9._-]+(&[A-Za-z0-9._-]+=[A-Za-z0-9._-]+)*)){0,1}"

  val urlRegex: Regex = s"$ksqlPrefix$ksqlUserPassRegex$ksqlServerRegex$ksqlPropsRegex\\z".r

  def parseUrl(url: String): KsqlConnectionValues = url match {
    case urlRegex(_, username, password, ksqlServer, port, _, props, _) =>
      KsqlConnectionValues(
        ksqlServer,
        port.toInt,
        Option(username),
        Option(password),
        Option(props).map(_.split("&").map(_.split("=")).map(p => p(0) -> p(1)).toMap).getOrElse(Map.empty)
      )
    case _ => throw InvalidUrl(url)
  }
}

class KsqlDriver extends Driver {

  override def acceptsURL(url: String): Boolean = Option(url).exists(_.startsWith(KsqlDriver.ksqlPrefix))

  override def jdbcCompliant: Boolean = false

  override def getPropertyInfo(url: String, info: Properties): scala.Array[DriverPropertyInfo] = scala.Array.empty

  override def getMinorVersion: Int = KsqlDriver.driverMinorVersion

  override def getMajorVersion: Int = KsqlDriver.driverMajorVersion

  override def getParentLogger: Logger = throw NotSupported("getParentLogger")

  override def connect(url: String, properties: Properties): Connection = {
    if (!acceptsURL(url)) throw InvalidUrl(url)

    val connection = buildConnection(KsqlDriver.parseUrl(url), properties)
    connection.validate()
    connection
  }

  private[jdbc] def buildConnection(values: KsqlConnectionValues, properties: Properties): KsqlConnection = {
    new KsqlConnection(values, properties)
  }
} 
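A quick sketch of the URL parsing above. The server address, credentials and property values are made up, and the printed representation is only approximate.

import com.github.mmolimar.ksql.jdbc.KsqlDriver

object KsqlUrlExample extends App {
  // Username/password and the trailing properties are both optional parts of the regex above.
  val values = KsqlDriver.parseUrl("jdbc:ksql://user:secret@localhost:8088?timeout=5000")
  println(values) // host "localhost", port 8088, credentials Some(user)/Some(secret), props Map(timeout -> 5000)
}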
Example 85
Source File: KsqlConnectionSpec.scala    From ksql-jdbc-driver   with Apache License 2.0 5 votes vote down vote up
package com.github.mmolimar.ksql.jdbc

import java.sql.{Connection, SQLException, SQLFeatureNotSupportedException}
import java.util.{Collections, Properties}

import com.github.mmolimar.ksql.jdbc.utils.TestUtils._
import io.confluent.ksql.rest.client.{KsqlRestClient, MockableKsqlRestClient, RestResponse}
import io.confluent.ksql.rest.entity._
import org.eclipse.jetty.http.HttpStatus.Code
import org.scalamock.scalatest.MockFactory
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

class KsqlConnectionSpec extends AnyWordSpec with Matchers with MockFactory {

  "A KsqlConnection" when {

    "validating specs" should {
      val values = KsqlConnectionValues("localhost", 8080, None, None, Map.empty[String, String])
      val mockKsqlRestClient = mock[MockableKsqlRestClient]
      val ksqlConnection = new KsqlConnection(values, new Properties) {
        override def init: KsqlRestClient = mockKsqlRestClient
      }

      "throw not supported exception if not supported" in {
        val methods = implementedMethods[KsqlConnection]
        reflectMethods[KsqlConnection](methods = methods, implemented = false, obj = ksqlConnection)
          .foreach(method => {
            assertThrows[SQLFeatureNotSupportedException] {
              method()
            }
          })
      }

      "work if implemented" in {
        assertThrows[SQLException] {
          ksqlConnection.isClosed
        }
        ksqlConnection.getTransactionIsolation should be(Connection.TRANSACTION_NONE)
        ksqlConnection.setClientInfo(new Properties)

        (mockKsqlRestClient.makeKsqlRequest(_: String)).expects(*)
          .returns(RestResponse.successful[KsqlEntityList](Code.OK, new KsqlEntityList))
        ksqlConnection.setClientInfo("", "")
        assertThrows[SQLException] {
          (mockKsqlRestClient.makeKsqlRequest(_: String)).expects(*)
            .returns(RestResponse.erroneous(Code.INTERNAL_SERVER_ERROR, new KsqlErrorMessage(-1, "", Collections.emptyList[String])))
          ksqlConnection.setClientInfo("", "")
        }

        ksqlConnection.isReadOnly should be(false)

        (mockKsqlRestClient.makeStatusRequest _: () => RestResponse[CommandStatuses]).expects
          .returns(RestResponse.successful[CommandStatuses]
            (Code.OK, new CommandStatuses(Collections.emptyMap[CommandId, CommandStatus.Status])))
        ksqlConnection.isValid(0) should be(true)

        Option(ksqlConnection.getMetaData) should not be None

        Option(ksqlConnection.createStatement) should not be None
        assertThrows[SQLFeatureNotSupportedException] {
          ksqlConnection.createStatement(-1, -1)
        }
        ksqlConnection.setAutoCommit(true)
        ksqlConnection.setAutoCommit(false)
        ksqlConnection.getAutoCommit should be(false)
        ksqlConnection.getSchema should be(None.orNull)
        ksqlConnection.getWarnings should be(None.orNull)
        ksqlConnection.getCatalog should be(None.orNull)
        ksqlConnection.setCatalog("test")
        ksqlConnection.getCatalog should be(None.orNull)

        (mockKsqlRestClient.close _).expects
        ksqlConnection.close()
        ksqlConnection.isClosed should be(true)
        ksqlConnection.commit()
      }
    }
  }

  "A ConnectionNotSupported" when {

    "validating specs" should {

      "throw not supported exception if not supported" in {

        val resultSet = new ConnectionNotSupported
        reflectMethods[ConnectionNotSupported](methods = Seq.empty, implemented = false, obj = resultSet)
          .foreach(method => {
            assertThrows[SQLFeatureNotSupportedException] {
              method()
            }
          })
      }
    }
  }

} 
Example 86
Source File: DefaultJdbcSchemaReader.scala    From quill   with Apache License 2.0 5 votes vote down vote up
package io.getquill.codegen.jdbc.gen

import java.sql.{ Connection, ResultSet }

import io.getquill.codegen.jdbc.DatabaseTypes.{ DatabaseType, Oracle }
import io.getquill.codegen.jdbc.model.JdbcTypes.{ JdbcConnectionMaker, JdbcSchemaReader }
import io.getquill.codegen.model.{ JdbcColumnMeta, JdbcTableMeta, RawSchema }
import io.getquill.codegen.util.StringUtil._
import io.getquill.util.Using
import scala.util.{ Success, Failure }

import scala.annotation.tailrec
import scala.collection.immutable.List

class DefaultJdbcSchemaReader(
  databaseType: DatabaseType
) extends JdbcSchemaReader {

  @tailrec
  private def resultSetExtractor[T](rs: ResultSet, extractor: (ResultSet) => T, acc: List[T] = List()): List[T] = {
    if (!rs.next())
      acc.reverse
    else
      resultSetExtractor(rs, extractor, extractor(rs) :: acc)
  }

  private[getquill] def schemaPattern(schema: String) =
    databaseType match {
      case Oracle => schema // Oracle meta fetch takes minutes to hours if schema is not specified
      case _      => null
    }

  def jdbcEntityFilter(ts: JdbcTableMeta) =
    ts.tableType.existsInSetNocase("table", "view", "user table", "user view", "base table")

  private[getquill] def extractTables(connectionMaker: () => Connection): List[JdbcTableMeta] = {
    val output = Using.Manager { use =>
      val conn = use(connectionMaker())
      val schema = conn.getSchema
      val rs = use {
        conn.getMetaData.getTables(
          null,
          schemaPattern(schema),
          null,
          null
        )
      }
      resultSetExtractor(rs, rs => JdbcTableMeta.fromResultSet(rs))
    }
    val unfilteredJdbcEntities =
      output match {
        case Success(value) => value
        case Failure(e)     => throw e
      }

    unfilteredJdbcEntities.filter(jdbcEntityFilter(_))
  }

  private[getquill] def extractColumns(connectionMaker: () => Connection): List[JdbcColumnMeta] = {
    val output = Using.Manager { use =>
      val conn = use(connectionMaker())
      val schema = conn.getSchema
      val rs = use {
        conn.getMetaData.getColumns(
          null,
          schemaPattern(schema),
          null,
          null
        )
      }
      resultSetExtractor(rs, rs => JdbcColumnMeta.fromResultSet(rs))
    }
    output match {
      case Success(value) => value
      case Failure(e)     => throw e
    }
  }

  override def apply(connectionMaker: JdbcConnectionMaker): Seq[RawSchema[JdbcTableMeta, JdbcColumnMeta]] = {
    val tableMap =
      extractTables(connectionMaker)
        .map(t => ((t.tableCat, t.tableSchem, t.tableName), t))
        .toMap

    val columns = extractColumns(connectionMaker)
    val tableColumns =
      columns
        .groupBy(c => (c.tableCat, c.tableSchem, c.tableName))
        .map({ case (tup, cols) => tableMap.get(tup).map(RawSchema(_, cols)) })
        .collect({ case Some(tbl) => tbl })

    tableColumns.toSeq
  }
} 
Example 87
Source File: StreamResultsOrBlowUpSpec.scala    From quill   with Apache License 2.0 5 votes vote down vote up
package io.getquill.integration

import java.sql.{ Connection, ResultSet }

import io.getquill._
import io.getquill.context.monix.Runner
import monix.execution.Scheduler
import monix.execution.schedulers.CanBlock
import org.scalatest.matchers.should.Matchers._

import scala.concurrent.duration.Duration


class StreamResultsOrBlowUpSpec extends Spec {

  case class Person(name: String, age: Int)

  private implicit val scheduler = Scheduler.io()

  // set to true in order to create a ResultSet type (i.e. a rewindable one)
  // that will force jdbc to load the entire ResultSet into memory and crash this test.
  val doBlowUp = false

  val ctx = new PostgresMonixJdbcContext(Literal, "testPostgresDB", Runner.default) {
    override protected def prepareStatementForStreaming(sql: String, conn: Connection, fetchSize: Option[Int]) = {
      val stmt =
        conn.prepareStatement(
          sql,
          if (doBlowUp) ResultSet.TYPE_SCROLL_SENSITIVE
          else ResultSet.TYPE_FORWARD_ONLY,
          ResultSet.CONCUR_READ_ONLY
        )
      fetchSize.foreach(stmt.setFetchSize(_))
      stmt
    }
  }
  import ctx.{ run => runQuill, _ }

  val numRows = 1000000L

  "stream a large result set without blowing up" in {
    val deletes = runQuill { query[Person].delete }
    deletes.runSyncUnsafe(Duration.Inf)(scheduler, CanBlock.permit)

    val inserts = quote {
      (numRows: Long) =>
        infix"""insert into person (name, age) select md5(random()::text), random()*10+1 from generate_series(1, ${numRows}) s(i)""".as[Insert[Int]]
    }

    runQuill(inserts(lift(numRows))).runSyncUnsafe(Duration.Inf)(scheduler, CanBlock.permit)

    // Not sure why, but foreachL causes an OutOfMemory exception anyway, and firstL causes a "ResultSet closed" exception
    val result = stream(query[Person], 100)
      .zipWithIndex
      .foldLeftL(0L)({
        case (totalYears, (person, index)) => {
          // Need to print something out as we stream, or Travis will think the build is stalled and kill it with the following message:
          // "No output has been received in the last 10m0s..."
          if (index % 10000 == 0) println(s"Streaming Test Row: ${index}")
          totalYears + person.age
        }
      })
      .runSyncUnsafe(Duration.Inf)(scheduler, CanBlock.permit)
    result should be > numRows
  }
} 
Example 88
Source File: PrepareMonixJdbcSpecBase.scala    From quill   with Apache License 2.0 5 votes vote down vote up
package io.getquill

import java.sql.{ Connection, PreparedStatement, ResultSet }

import io.getquill.context.jdbc.ResultSetExtractor
import io.getquill.context.sql.ProductSpec
import monix.eval.Task
import org.scalactic.Equality

trait PrepareMonixJdbcSpecBase extends ProductSpec {

  implicit val productEq = new Equality[Product] {
    override def areEqual(a: Product, b: Any): Boolean = b match {
      case Product(_, desc, sku) => desc == a.description && sku == a.sku
      case _                     => false
    }
  }

  def productExtractor: ResultSet => Product

  def withOrderedIds(products: List[Product]) =
    products.zipWithIndex.map { case (product, id) => product.copy(id = id.toLong + 1) }

  def singleInsert(conn: => Connection)(prep: Connection => Task[PreparedStatement]) = {
    Task(conn).bracket { conn =>
      prep(conn).bracket { stmt =>
        Task(stmt.execute())
      }(stmt => Task(stmt.close()))
    }(conn => Task(conn.close()))
  }

  def batchInsert(conn: => Connection)(prep: Connection => Task[List[PreparedStatement]]) = {
    Task(conn).bracket { conn =>
      prep(conn).flatMap(stmts =>
        Task.sequence(
          stmts.map(stmt =>
            Task(stmt).bracket { stmt =>
              Task(stmt.execute())
            }(stmt => Task(stmt.close())))
        ))
    }(conn => Task(conn.close()))
  }

  def extractResults[T](conn: => Connection)(prep: Connection => Task[PreparedStatement])(extractor: ResultSet => T) = {
    Task(conn).bracket { conn =>
      prep(conn).bracket { stmt =>
        Task(stmt.executeQuery()).bracket { rs =>
          Task(ResultSetExtractor(rs, extractor))
        }(rs => Task(rs.close()))
      }(stmt => Task(stmt.close()))
    }(conn => Task(conn.close()))
  }

  def extractProducts(conn: => Connection)(prep: Connection => Task[PreparedStatement]) =
    extractResults(conn)(prep)(productExtractor)
} 
Example 89
Source File: PrepareJdbcSpecBase.scala    From quill   with Apache License 2.0 5 votes vote down vote up
package io.getquill.context.jdbc
import java.sql.{ Connection, PreparedStatement, ResultSet }

import io.getquill.context.sql.ProductSpec
import io.getquill.util.Using.Manager
import org.scalactic.Equality
import scala.util.{ Success, Failure }

trait PrepareJdbcSpecBase extends ProductSpec {

  implicit val productEq = new Equality[Product] {
    override def areEqual(a: Product, b: Any): Boolean = b match {
      case Product(_, desc, sku) => desc == a.description && sku == a.sku
      case _                     => false
    }
  }

  def productExtractor: ResultSet => Product

  def withOrderedIds(products: List[Product]) =
    products.zipWithIndex.map { case (product, id) => product.copy(id = id.toLong + 1) }

  def singleInsert(conn: => Connection)(prep: Connection => PreparedStatement) = {
    val flag = Manager { use =>
      val c = use(conn)
      val s = use(prep(c))
      s.execute()
    }
    flag match {
      case Success(value) => value
      case Failure(e)     => throw e
    }
  }

  def batchInsert(conn: => Connection)(prep: Connection => List[PreparedStatement]) = {
    val r = Manager { use =>
      val c = use(conn)
      val st = prep(c)
      appendExecuteSequence(st)
    }
    r.flatten match {
      case Success(value) => value
      case Failure(e)     => throw e
    }
  }

  def extractResults[T](conn: => Connection)(prep: Connection => PreparedStatement)(extractor: ResultSet => T) = {
    val r = Manager { use =>
      val c = use(conn)
      val st = use(prep(c))
      val rs = st.executeQuery()
      ResultSetExtractor(rs, extractor)
    }
    r match {
      case Success(v) => v
      case Failure(e) => throw e
    }
  }

  def extractProducts(conn: => Connection)(prep: Connection => PreparedStatement): List[Product] =
    extractResults(conn)(prep)(productExtractor)

  def appendExecuteSequence(actions: => List[PreparedStatement]) = {
    Manager { use =>
      actions.map { stmt =>
        val s = use(stmt)
        s.execute()
      }
    }
  }
} 
Example 90
Source File: SQLCheck.scala    From DataQuality   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package it.agilelab.bigdata.DataQuality.checks.SQLChecks

import java.sql.Connection

import it.agilelab.bigdata.DataQuality.checks.{CheckResult, CheckUtil}
import it.agilelab.bigdata.DataQuality.sources.DatabaseConfig

import scala.util.Try


case class SQLCheck(
    id: String,
    description: String,
    subType: String,
    source: String,
    sourceConfig: DatabaseConfig,
    query: String,
    date: String // opt
) {

  def executeCheck(connection: Connection): CheckResult = {
    val transformations = SQLCheckProcessor.getTransformations(subType)

    val statement = connection.createStatement()
    statement.setFetchSize(1000)

    val queryResult = statement.executeQuery(query)

    val result = transformations._1(queryResult)
    statement.close()

    val status = CheckUtil.tryToStatus(Try(result), transformations._2)

    val cr =
      CheckResult(
        this.id,
        subType,
        this.description,
        this.source,
        "",
        Some(""),
        0.0,
        status.stringValue,
        this.query,
        this.date
      )

    cr
  }

} 
Example 91
Source File: PostgresReader.scala    From DataQuality   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package it.agilelab.bigdata.DataQuality.utils.io.db.readers

import java.sql.{Connection, DriverManager, ResultSet}

import it.agilelab.bigdata.DataQuality.sources.DatabaseConfig


case class PostgresReader(config: DatabaseConfig) extends TableReader {

  override val connectionUrl: String = "jdbc:postgresql://" + config.host

  override def runQuery[T](query: String,
                           transformOutput: ResultSet => T): T = {
    val connection = getConnection

    val statement = connection.createStatement()
    statement.setFetchSize(1000)

    val queryResult = statement.executeQuery(query)
    val result = transformOutput(queryResult)
    statement.close()
    result
  }

  override def getConnection: Connection = {
    val connectionProperties = new java.util.Properties()
    config.user match {
      case Some(user) => connectionProperties.put("user", user)
      case None       =>
    }
    config.password match {
      case Some(pwd) => connectionProperties.put("password", pwd)
      case None      =>
    }
    connectionProperties.put("driver", "org.postgresql.Driver")

    DriverManager.getConnection(connectionUrl, connectionProperties)
  }

} 
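A hedged sketch of running a scalar query through the reader above. The DatabaseConfig values are placeholders; note that host is expected to carry host, port and database, since it is appended directly to the "jdbc:postgresql://" prefix.

import it.agilelab.bigdata.DataQuality.sources.DatabaseConfig
import it.agilelab.bigdata.DataQuality.utils.io.db.readers.PostgresReader

object PostgresReaderExample extends App {
  val config = DatabaseConfig(
    id = "pg-example",
    subtype = "POSTGRES",
    host = "localhost:5432/data_quality",
    port = None,
    service = None,
    user = Some("dq_user"),
    password = Some("secret"),
    schema = Some("public")
  )

  val reader = PostgresReader(config)

  // Transform the ResultSet into a single count while the statement is still open.
  val rowCount: Int = reader.runQuery("SELECT count(*) AS cnt FROM checks", rs => { rs.next(); rs.getInt("cnt") })
  println(s"checks: $rowCount")
}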
Example 92
Source File: TableReader.scala    From DataQuality   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package it.agilelab.bigdata.DataQuality.utils.io.db.readers

import java.sql.{Connection, ResultSet}
import java.util.Properties

import org.apache.spark.sql.{DataFrame, SQLContext}


// Reconstructed trait header (assumption): the original declaration was elided with its doc
// comment. The members below are inferred from what PostgresReader overrides and what
// DatabaseConfig relies on; the default body of getUrl is an assumption.
trait TableReader {

  val connectionUrl: String

  def getConnection: Connection

  def getUrl: String = connectionUrl

  def runQuery[T](query: String, transformOutput: ResultSet => T): T

  def loadData(
      table: String,
      username: Option[String],
      password: Option[String])(implicit sqlContext: SQLContext): DataFrame = {
    val connectionProperties = new Properties()

    (username, password) match {
      case (Some(u), Some(p)) =>
        connectionProperties.put("user", u)
        connectionProperties.put("password", p)
      case _ =>
    }

    sqlContext.read.jdbc(connectionUrl, table, connectionProperties)
  }

} 
Example 93
Source File: DatabaseConfig.scala    From DataQuality   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package it.agilelab.bigdata.DataQuality.sources

import java.sql.Connection

import com.typesafe.config.Config
import it.agilelab.bigdata.DataQuality.exceptions.IllegalParameterException
import it.agilelab.bigdata.DataQuality.utils
import it.agilelab.bigdata.DataQuality.utils.io.db.readers.{ORCLReader, PostgresReader, SQLiteReader, TableReader}
import org.apache.spark.sql.{DataFrame, SQLContext}

import scala.util.Try


case class DatabaseConfig(
                           id: String,
                           subtype: String,
                           host: String,
                           port: Option[String],
                           service: Option[String],
                           user: Option[String],
                           password: Option[String],
                           schema: Option[String]
                         ) {

  // Constructor for building a DatabaseConfig from a Typesafe Config block
  def this(config: Config) = {
    this(
      Try(config.getString("id")).getOrElse(""),
      config.getString("subtype"),
      config.getString("host"),
      Try(config.getString("port")).toOption,
      Try(config.getString("service")).toOption,
      Try(config.getString("user")).toOption,
      Try(config.getString("password")).toOption,
      Try(config.getString("schema")).toOption
    )
  }

  private val dbReader: TableReader = subtype match {
    case "ORACLE"   => ORCLReader(this)
    case "SQLITE"   => SQLiteReader(this)
    case "POSTGRES" => PostgresReader(this)
    case x          => throw IllegalParameterException(x)
  }

  def getConnection: Connection = dbReader.getConnection
  def getUrl: String = dbReader.getUrl
  // the trick here is that table credentials can be different from database one,
  // so that function allow you to connect to the database with multiple credentials
  // without specification of multiple databases
  def loadData(table: String,
               user: Option[String] = this.user,
               password: Option[String] = this.password)(
                implicit sqlContext: SQLContext): DataFrame =
    dbReader.loadData(utils.makeTableName(schema, table), user, password)
} 
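A short sketch of the Config-based constructor above, using an inline HOCON snippet (all values are illustrative, and a reachable PostgreSQL instance is assumed for getConnection).

import com.typesafe.config.ConfigFactory
import it.agilelab.bigdata.DataQuality.sources.DatabaseConfig

object DatabaseConfigExample extends App {
  val conf = ConfigFactory.parseString(
    """
      |id = "pg-example"
      |subtype = "POSTGRES"
      |host = "localhost:5432/data_quality"
      |user = "dq_user"
      |password = "secret"
      |schema = "public"
    """.stripMargin)

  val dbConfig = new DatabaseConfig(conf)
  // Delegates to PostgresReader because subtype is POSTGRES; missing keys (port, service) become None.
  val connection = dbConfig.getConnection
  println(connection.isValid(5))
  connection.close()
}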
Example 94
Source File: PostgresDialect.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Types}

import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._


private object PostgresDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.REAL) {
      Some(FloatType)
    } else if (sqlType == Types.SMALLINT) {
      Some(ShortType)
    } else if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
      Some(BinaryType)
    } else if (sqlType == Types.OTHER) {
      Some(StringType)
    } else if (sqlType == Types.ARRAY) {
      val scale = md.build.getLong("scale").toInt
      // postgres array type names start with underscore
      toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_))
    } else None
  }

  private def toCatalystType(
      typeName: String,
      precision: Int,
      scale: Int): Option[DataType] = typeName match {
    case "bool" => Some(BooleanType)
    case "bit" => Some(BinaryType)
    case "int2" => Some(ShortType)
    case "int4" => Some(IntegerType)
    case "int8" | "oid" => Some(LongType)
    case "float4" => Some(FloatType)
    case "money" | "float8" => Some(DoubleType)
    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
      Some(StringType)
    case "bytea" => Some(BinaryType)
    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
    case "date" => Some(DateType)
    case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale))
    case _ => None
  }

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.CHAR))
    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
    case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
    case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
    case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case t: DecimalType => Some(
      JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
      getJDBCType(et).map(_.databaseTypeDefinition)
        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
    case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt");
    case _ => None
  }

  override def getTableExistsQuery(table: String): String = {
    s"SELECT 1 FROM $table LIMIT 1"
  }

  
  override def getTruncateQuery(table: String): String = {
    s"TRUNCATE TABLE ONLY $table"
  }

  override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
    super.beforeFetch(connection, properties)

    // According to the postgres jdbc documentation we need to be in autocommit=false if we
    // actually want fetchsize to take effect (a fetch size of 0 means "fetch all the rows").
    // This keeps the driver from caching the whole result set in memory while fetching.
    //
    // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
    //
    if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
      connection.setAutoCommit(false)
    }
  }

  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
} 
Example 95
Source File: ExportStep.scala    From ingraph   with Eclipse Public License 1.0 5 votes vote down vote up
package ingraph.compiler.sql

import java.sql.Connection

import ingraph.compiler.sql.Util.withResources
import ingraph.compiler.sql.driver.ValueJsonConversion
import org.neo4j.driver.v1.{Session, Value}

import scala.collection.JavaConverters._

class ExportStep(val exportCypherQuery: String, val tableName: String) {

  def exportToTable(cypherSession: Session, sqlConnection: Connection): Unit = {
    val cypherResult = cypherSession.run(exportCypherQuery)

    val keysInRecord = cypherResult.keys.size
    val valueParameters = cypherResult.keys.asScala
      .map(key =>
        if (key == "value")
          "?::jsonb"
        else
          "?")
      .mkString(", ")
    val insertQueryString = s"INSERT INTO $tableName VALUES ($valueParameters)"

    withResources(sqlConnection.prepareStatement(insertQueryString))(insertStatement => {
      for (cypherRecord <- cypherResult.asScala) {
        for (keyIndex <- 0 until keysInRecord) {
          val columnIndex = keyIndex + 1
          val cypherValue = cypherRecord.get(keyIndex)
          val cypherValueObject = cypherValue.asObject

          val value =
            if (cypherRecord.keys().get(keyIndex) == "value")
              ValueJsonConversion.gson.toJson(cypherValue, classOf[Value])
            else
              cypherValueObject

          insertStatement.setObject(columnIndex, value)
        }
        insertStatement.addBatch()
      }

      insertStatement.executeBatch()
    })
  }
} 
Example 96
Source File: ExportSteps.scala    From ingraph   with Eclipse Public License 1.0 5 votes vote down vote up
package ingraph.compiler.sql

import java.sql.Connection

import org.neo4j.driver.v1.Session

object ExportSteps {
  private val exportVertex = new ExportStep(
    """// vertex
      |MATCH (n)
      |RETURN id(n) AS vertex_id""".stripMargin, "vertex")

  private val exportEdge = new ExportStep(
    """// edge
      |MATCH (from)-[edge]->(to)
      |RETURN id(edge) AS edge_id, id(from) AS from, id(to) AS to, type(edge) AS type""".stripMargin, "edge")

  private val exportLabel = new ExportStep(
    """// label
      |MATCH (n)
      |UNWIND labels(n) AS name
      |RETURN id(n) AS parent, name""".stripMargin, "label")

  private val exportVertex_property = new ExportStep(
    """// vertex_property
      |MATCH (n)
      |UNWIND keys(n) AS key
      |RETURN id(n) AS parent, key, properties(n)[key] AS value""".stripMargin, "vertex_property")

  private val exportEdge_property = new ExportStep(
    """// edge_property
      |MATCH ()-[e]->()
      |UNWIND keys(e) AS key
      |RETURN id(e) AS parent, key, properties(e)[key] AS value""".stripMargin, "edge_property")

  private val steps = Array(exportVertex, exportEdge, exportLabel, exportVertex_property, exportEdge_property)

  def execute(cypherSession: Session, sqlConnection: Connection): Unit = {
    for (step <- steps)
      step.exportToTable(cypherSession, sqlConnection)
  }
} 
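A wiring sketch for the steps above. The Neo4j and PostgreSQL endpoints and credentials are placeholders, the 1.x Neo4j Java driver API is assumed, and the target tables (vertex, edge, label, vertex_property, edge_property) are assumed to exist.

import java.sql.DriverManager

import ingraph.compiler.sql.ExportSteps
import org.neo4j.driver.v1.{AuthTokens, GraphDatabase}

object ExportExample extends App {
  val neo4jDriver = GraphDatabase.driver("bolt://localhost:7687", AuthTokens.basic("neo4j", "secret"))
  val cypherSession = neo4jDriver.session()
  val sqlConnection = DriverManager.getConnection(
    "jdbc:postgresql://localhost:5432/ingraph?user=ingraph&password=secret")
  try {
    ExportSteps.execute(cypherSession, sqlConnection)
  } finally {
    sqlConnection.close()
    cypherSession.close()
    neo4jDriver.close()
  }
}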
Example 97
Source File: RichConnection.scala    From s4ds   with Apache License 2.0 5 votes vote down vote up
// RichConnection.scala

import java.sql.{Connection, ResultSet}

class RichConnection(val underlying:Connection) {

  
  def withQuery[T](query:String)(f:ResultSet => T):T = {
    val statement = underlying.prepareStatement(query)
    val results = statement.executeQuery
    try {
      f(results) // loan the ResultSet to the client
    }
    finally {
      // Ensure all the resources get freed.
      results.close
      statement.close
    }
  }
} 
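A usage sketch of the loan pattern above, assuming the H2 driver is on the classpath; the table and data are made up, and the rows are materialised before withQuery closes the statement and its ResultSet.

import java.sql.DriverManager

object RichConnectionExample extends App {
  val raw = DriverManager.getConnection("jdbc:h2:mem:demo;DB_CLOSE_DELAY=-1")
  val setup = raw.createStatement()
  setup.execute("CREATE TABLE physicists (name VARCHAR(64))")
  setup.execute("INSERT INTO physicists VALUES ('Noether'), ('Curie')")
  setup.close()

  val conn = new RichConnection(raw)
  val names = conn.withQuery("SELECT name FROM physicists") { results =>
    // Materialise the rows before withQuery frees the ResultSet.
    val buffer = scala.collection.mutable.ListBuffer.empty[String]
    while (results.next()) buffer += results.getString("name")
    buffer.toList
  }
  println(names)
  raw.close()
}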
Example 98
Source File: PostgresDialect.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Types}

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._


private object PostgresDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
      Some(BinaryType)
    } else if (sqlType == Types.OTHER) {
      toCatalystType(typeName).filter(_ == StringType)
    } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') {
      toCatalystType(typeName.drop(1)).map(ArrayType(_))
    } else None
  }

  // TODO: support more type names.
  private def toCatalystType(typeName: String): Option[DataType] = typeName match {
    case "bool" => Some(BooleanType)
    case "bit" => Some(BinaryType)
    case "int2" => Some(ShortType)
    case "int4" => Some(IntegerType)
    case "int8" | "oid" => Some(LongType)
    case "float4" => Some(FloatType)
    case "money" | "float8" => Some(DoubleType)
    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
      Some(StringType)
    case "bytea" => Some(BinaryType)
    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
    case "date" => Some(DateType)
    case "numeric" => Some(DecimalType.SYSTEM_DEFAULT)
    case _ => None
  }

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("TEXT", Types.CHAR))
    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
    case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
    case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
      getJDBCType(et).map(_.databaseTypeDefinition)
        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
    case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt");
    case _ => None
  }

  override def getTableExistsQuery(table: String): String = {
    s"SELECT 1 FROM $table LIMIT 1"
  }

  override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
    super.beforeFetch(connection, properties)

    // According to the postgres jdbc documentation we need to be in autocommit=false if we
    // actually want fetchsize to take effect (a fetch size of 0 means "fetch all the rows").
    // This keeps the driver from caching the whole result set in memory while fetching.
    //
    // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
    //
    if (properties.getOrElse("fetchsize", "0").toInt > 0) {
      connection.setAutoCommit(false)
    }

  }
} 
Example 99
Source File: PostgresIntegrationSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.jdbc

import java.sql.Connection
import java.util.Properties

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Literal, If}
import org.apache.spark.tags.DockerTest

@DockerTest
class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
  override val db = new DatabaseOnDocker {
    override val imageName = "postgres:9.4.5"
    override val env = Map(
      "POSTGRES_PASSWORD" -> "rootpass"
    )
    override val jdbcPort = 5432
    override def getJdbcUrl(ip: String, port: Int): String =
      s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass"
  }

  override def dataPreparation(conn: Connection): Unit = {
    conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
    conn.setCatalog("foo")
    conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
      + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
      + "c10 integer[], c11 text[], c12 real[])").executeUpdate()
    conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
      + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
      + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}')""").executeUpdate()
  }

  test("Type mapping for various types") {
    val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
    val rows = df.collect()
    assert(rows.length == 1)
    val types = rows(0).toSeq.map(x => x.getClass)
    assert(types.length == 13)
    assert(classOf[String].isAssignableFrom(types(0)))
    assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
    assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
    assert(classOf[java.lang.Long].isAssignableFrom(types(3)))
    assert(classOf[java.lang.Boolean].isAssignableFrom(types(4)))
    assert(classOf[Array[Byte]].isAssignableFrom(types(5)))
    assert(classOf[Array[Byte]].isAssignableFrom(types(6)))
    assert(classOf[java.lang.Boolean].isAssignableFrom(types(7)))
    assert(classOf[String].isAssignableFrom(types(8)))
    assert(classOf[String].isAssignableFrom(types(9)))
    assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
    assert(classOf[Seq[String]].isAssignableFrom(types(11)))
    assert(classOf[Seq[Double]].isAssignableFrom(types(12)))
    assert(rows(0).getString(0).equals("hello"))
    assert(rows(0).getInt(1) == 42)
    assert(rows(0).getDouble(2) == 1.25)
    assert(rows(0).getLong(3) == 123456789012345L)
    assert(rows(0).getBoolean(4) == false)
    // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's...
    assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5),
      Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49)))
    assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6),
      Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte)))
    assert(rows(0).getBoolean(7) == true)
    assert(rows(0).getString(8) == "172.16.0.42")
    assert(rows(0).getString(9) == "192.168.0.0/16")
    assert(rows(0).getSeq(10) == Seq(1, 2))
    assert(rows(0).getSeq(11) == Seq("a", null, "b"))
    assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f))
  }

  test("Basic write test") {
    val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
    // Test only that it doesn't crash.
    df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
    // Test write null values.
    df.select(df.queryExecution.analyzed.output.map { a =>
      Column(Literal.create(null, a.dataType)).as(a.name)
    }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
  }
} 
Example 100
Source File: SqlRequestInvokerActor.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.sql.actor

import java.sql.Connection

import akka.actor.Props
import akka.pattern.{ask, pipe}
import akka.util.Timeout
import asura.common.actor.BaseActor
import asura.common.util.FutureUtils
import asura.core.CoreConfig
import asura.core.es.model.SqlRequest.SqlRequestBody
import asura.core.sql.actor.MySqlConnectionCacheActor.GetConnectionMessage
import asura.core.sql.{MySqlConnector, SqlConfig, SqlParserUtils}

import scala.concurrent.{ExecutionContext, Future}

class SqlRequestInvokerActor extends BaseActor {

  implicit val ec: ExecutionContext = context.dispatcher
  implicit val timeout: Timeout = CoreConfig.DEFAULT_ACTOR_ASK_TIMEOUT

  val connectionCacheActor = context.actorOf(MySqlConnectionCacheActor.props())

  override def receive: Receive = {
    case requestBody: SqlRequestBody =>
      getResponse(requestBody) pipeTo sender()
    case _ =>
      Future.failed(new RuntimeException("Unknown message type")) pipeTo sender()
  }

  def getResponse(requestBody: SqlRequestBody): Future[Object] = {
    implicit val sqlEc = SqlConfig.SQL_EC
    val futConn = (connectionCacheActor ? GetConnectionMessage(requestBody)).asInstanceOf[Future[Connection]]
    val (isOk, errMsg) = SqlParserUtils.isSelectStatement(requestBody.sql)
    if (null == errMsg) {
      futConn.flatMap(conn => {
        if (isOk) {
          Future {
            MySqlConnector.executeQuery(conn, requestBody.sql)
          }
        } else {
          Future {
            MySqlConnector.executeUpdate(conn, requestBody.sql)
          }
        }
      })
    } else {
      FutureUtils.requestFail(errMsg)
    }
  }

}

object SqlRequestInvokerActor {

  def props() = Props(new SqlRequestInvokerActor())
} 
Example 101
Source File: MySqlConnectionCacheActor.scala    From asura   with MIT License 5 votes vote down vote up
package asura.core.sql.actor

import java.sql.Connection

import akka.actor.Props
import akka.pattern.pipe
import akka.util.Timeout
import asura.common.actor.BaseActor
import asura.common.cache.LRUCache
import asura.core.CoreConfig
import asura.core.es.model.SqlRequest.SqlRequestBody
import asura.core.sql.actor.MySqlConnectionCacheActor.GetConnectionMessage
import asura.core.sql.{MySqlConnector, SqlConfig}

import scala.concurrent.{ExecutionContext, Future}

class MySqlConnectionCacheActor(size: Int) extends BaseActor {

  implicit val ec: ExecutionContext = context.dispatcher
  implicit val timeout: Timeout = CoreConfig.DEFAULT_ACTOR_ASK_TIMEOUT

  private val lruCache = LRUCache[String, Connection](size, (_, conn) => {
    conn.close()
  })

  override def receive: Receive = {
    case GetConnectionMessage(sqlRequest) =>
      getConnection(sqlRequest) pipeTo sender()
    case _ =>
      Future.failed(new RuntimeException("Unknown message type")) pipeTo sender()
  }

  private def getConnection(request: SqlRequestBody): Future[Connection] = {
    Future {
      val key = generateCacheKey(request)
      val conn = lruCache.get(key)
      if (null == conn || !conn.isValid(SqlConfig.SQL_CONN_CHECK_TIMEOUT)) {
        val newConn = MySqlConnector.connect(request)
        lruCache.put(key, newConn)
        newConn
      } else {
        conn
      }
    }(SqlConfig.SQL_EC)
  }

  private def generateCacheKey(request: SqlRequestBody): String = {
    val sb = StringBuilder.newBuilder
    sb.append(request.username).append(":")
      .append(request.encryptedPass).append("@")
      .append(request.host).append(":")
      .append(request.port).append("/")
      .append(request.database)
    sb.toString()
  }
}

object MySqlConnectionCacheActor {

  def props(size: Int = SqlConfig.DEFAULT_MYSQL_CONNECTOR_CACHE_SIZE) = Props(new MySqlConnectionCacheActor(size))

  case class GetConnectionMessage(request: SqlRequestBody)

} 
Example 102
Source File: Database.scala    From lighthouse   with Apache License 2.0 5 votes vote down vote up
package be.dataminded.lighthouse.common

import java.sql.{Connection, DriverManager}

import be.dataminded.lighthouse.datalake._

class Database(val driverClassName: String, url: String, properties: Map[String, String] = Map.empty) {

  def withConnection[A](autoCommit: Boolean)(block: (Connection) => A): A = {
    val connection = createConnection(autoCommit)
    try {
      block(connection)
    } finally {
      connection.close()
    }
  }

  def withConnection[A](block: (Connection) => A): A = withConnection(autoCommit = true)(block)

  private def createConnection(autoCommit: Boolean): Connection = {
    Class.forName(driverClassName)
    val connection = DriverManager.getConnection(url, properties)
    connection.setAutoCommit(autoCommit)
    connection
  }
}

object Database {

  def apply(driver: String, url: String, properties: Map[String, String] = Map.empty): Database =
    new Database(driver, url, properties)

  def inMemory(name: String, urlOptions: Map[String, String] = Map.empty): Database = {
    val urlExtra = urlOptions.map { case (k, v) => s"$k=$v" }.mkString(";", ";", "")
    val url      = s"jdbc:h2:mem:$name$urlExtra;"
    new Database("org.h2.Driver", url)
  }
} 
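A minimal sketch of the helper above: spin up an in-memory H2 database and run statements inside managed connections (names and DDL are illustrative; DB_CLOSE_DELAY keeps the database alive across connections).

import be.dataminded.lighthouse.common.Database

object DatabaseExample extends App {
  val database = Database.inMemory("lighthouse-test", Map("DB_CLOSE_DELAY" -> "-1"))

  database.withConnection { connection =>
    connection.createStatement().execute("CREATE TABLE IF NOT EXISTS events (id INT, payload VARCHAR(255))")
  }

  val count = database.withConnection { connection =>
    val rs = connection.createStatement().executeQuery("SELECT count(*) FROM events")
    rs.next()
    rs.getInt(1)
  }
  println(s"events: $count")
}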
Example 103
Source File: HiveJDBCUtils.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.hive.utilities

import java.security._
import java.sql.{Connection, DriverManager, Statement}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.security.UserGroupInformation

import com.paypal.gimel.common.conf.{GimelConstants, GimelProperties}

object HiveJDBCUtils {

  def apply(conf: GimelProperties, cluster: String): HiveJDBCUtils = {
    new HiveJDBCUtils(conf, cluster)
  }
}

class HiveJDBCUtils(val props: GimelProperties, cluster: String = "unknown_cluster") {
  val logger = com.paypal.gimel.logger.Logger()

  logger.info("Using Supplied KeyTab to authenticate KDC...")
  val conf = new Configuration
  conf.set(GimelConstants.SECURITY_AUTH, "kerberos")
  UserGroupInformation.setConfiguration(conf)
  val ugi: UserGroupInformation = UserGroupInformation.loginUserFromKeytabAndReturnUGI(props.principal, props.keytab)
  UserGroupInformation.setLoginUser(ugi)


  
  // Reconstructed helper (assumption): the original withConnection was elided along with its
  // doc comment. It loans a Hive JDBC connection to the caller and closes it afterwards; the
  // connection URL below is a placeholder for whatever GimelProperties actually supplies.
  def withConnection(fn: Connection => Any): Any = {
    val connection = DriverManager.getConnection(
      s"jdbc:hive2://$cluster:10000/default;principal=${props.principal}")
    try fn(connection) finally if (!connection.isClosed) connection.close()
  }

  def withStatement(fn: Statement => Any): Any = {
    def MethodName: String = new Exception().getStackTrace.apply(1).getMethodName

    logger.info(" @Begin --> " + MethodName)
    withConnection {
      connection =>
        val statement = connection.createStatement
        var output: Any = None
        try {
          output = fn(statement)
        } catch {
          case e: Throwable =>
            e.printStackTrace
            throw e
        }
        finally {
          if (!statement.isClosed) {
            statement.close
          }
        }
        output
    }
  }

} 
Example 104
Source File: SingletonConnection.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.jdbc.utilities

import java.sql.{Connection, DriverManager}

object SingletonConnection {

  private var connection: Connection = null

  def getConnection(url: String, username: String, password: String): Connection = synchronized {
    if (connection == null || connection.isClosed) {
      connection = DriverManager.getConnection(url, username, password)
    }
    connection
  }


  def getConnection(jdbcConnectionUtility: JDBCConnectionUtility): Connection = synchronized {
    if (connection == null || connection.isClosed) {
      connection = jdbcConnectionUtility.getJdbcConnectionAndSetQueryBand()
    }
    connection
  }
} 
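A short sketch of the credential-based variant above (URL and credentials are placeholders). Repeated calls return the same connection object for as long as it stays open.

import com.paypal.gimel.jdbc.utilities.SingletonConnection

object SingletonConnectionExample extends App {
  val url = "jdbc:mysql://localhost:3306/gimel" // placeholder endpoint
  val first = SingletonConnection.getConnection(url, "gimel_user", "secret")
  val second = SingletonConnection.getConnection(url, "gimel_user", "secret")
  println(first eq second) // true while the connection remains open
  first.close()
}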
Example 105
Source File: PushDownJdbcRDD.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.jdbc.utilities

import java.sql.{Connection, ResultSet}

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.JdbcRDD
import org.apache.spark.sql.Row

import com.paypal.gimel.common.utilities.GenericUtils
import com.paypal.gimel.logger.Logger


class PushDownJdbcRDD(sc: SparkContext,
                      getConnection: () => Connection,
                      sql: String,
                      mapRow: ResultSet => Row = PushDownJdbcRDD.resultSetToRow)
  extends JdbcRDD[Row](sc, getConnection, sql, 0, 100, 1, mapRow)
    with Logging {

  override def compute(thePart: Partition,
                       context: TaskContext): Iterator[Row] = {
    val logger = Logger(this.getClass.getName)
    val functionName = s"[QueryHash: ${sql.hashCode}]"
    logger.info(s"Proceeding to execute push down query $functionName: $sql")
    val queryResult: String = GenericUtils.time(functionName, Some(logger)) {
      JDBCConnectionUtility.withResources(getConnection()) { connection =>
        JdbcAuxiliaryUtilities.executeQueryAndReturnResultString(
          sql,
          connection
        )
      }
    }
    Seq(Row(queryResult)).iterator
  }
}

object PushDownJdbcRDD {
  def resultSetToRow(rs: ResultSet): Row = {
    // JDBC column indices are 1-based, so the first column is index 1.
    Row(rs.getString(1))
  }
} 
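A hedged sketch of running a push-down query. An active SparkContext and a MySQL driver on the classpath are assumed, and the JDBC endpoint and table are placeholders. Because compute() executes the full statement on the database, the RDD yields a single row containing the serialised result.

import java.sql.DriverManager

import com.paypal.gimel.jdbc.utilities.PushDownJdbcRDD
import org.apache.spark.{SparkConf, SparkContext}

object PushDownExample extends App {
  val sc = new SparkContext(new SparkConf().setAppName("pushdown-example").setMaster("local[1]"))

  val rdd = new PushDownJdbcRDD(
    sc,
    () => DriverManager.getConnection("jdbc:mysql://localhost:3306/gimel", "gimel_user", "secret"),
    "SELECT count(*) FROM transactions")

  rdd.collect().foreach(println)
  sc.stop()
}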
Example 106
Source File: MySQLPoolManager.scala    From spark_mysql   with Apache License 2.0 5 votes vote down vote up
package utils

import java.sql.Connection

import com.mchange.v2.c3p0.ComboPooledDataSource

/**
  * Created with IntelliJ IDEA.
  * Author: [email protected]
  * Description: MySQL connection pool manager
  * Date: Created in 2018-11-17 12:43
  */
object MySQLPoolManager {
  var mysqlManager: MysqlPool = _

  def getMysqlManager: MysqlPool = {
    synchronized {
      if (mysqlManager == null) {
        mysqlManager = new MysqlPool
      }
    }
    mysqlManager
  }

  class MysqlPool extends Serializable {
    private val cpds: ComboPooledDataSource = new ComboPooledDataSource(true)
    try {
      cpds.setJdbcUrl(PropertyUtils.getFileProperties("mysql-user.properties", "mysql.jdbc.url"))
      cpds.setDriverClass(PropertyUtils.getFileProperties("mysql-user.properties", "mysql.pool.jdbc.driverClass"))
      cpds.setUser(PropertyUtils.getFileProperties("mysql-user.properties", "mysql.jdbc.username"))
      cpds.setPassword(PropertyUtils.getFileProperties("mysql-user.properties", "mysql.jdbc.password"))
      cpds.setMinPoolSize(PropertyUtils.getFileProperties("mysql-user.properties", "mysql.pool.jdbc.minPoolSize").toInt)
      cpds.setMaxPoolSize(PropertyUtils.getFileProperties("mysql-user.properties", "mysql.pool.jdbc.maxPoolSize").toInt)
      cpds.setAcquireIncrement(PropertyUtils.getFileProperties("mysql-user.properties", "mysql.pool.jdbc.acquireIncrement").toInt)
      cpds.setMaxStatements(PropertyUtils.getFileProperties("mysql-user.properties", "mysql.pool.jdbc.maxStatements").toInt)
    } catch {
      case e: Exception => e.printStackTrace()
    }

    def getConnection: Connection = {
      try {
        cpds.getConnection()
      } catch {
        case ex: Exception =>
          ex.printStackTrace()
          null
      }
    }

    def close(): Unit = {
      try {
        cpds.close()
      } catch {
        case ex: Exception =>
          ex.printStackTrace()
      }
    }
  }

} 
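A usage sketch for the pool above, assuming the mysql-user.properties entries are in place. Closing the connection returns it to the c3p0 pool rather than tearing it down.

import utils.MySQLPoolManager

object MySQLPoolExample extends App {
  val connection = MySQLPoolManager.getMysqlManager.getConnection
  try {
    val rs = connection.createStatement().executeQuery("SELECT NOW()")
    while (rs.next()) println(rs.getString(1))
  } finally {
    if (connection != null) connection.close() // hands the connection back to the pool
  }
}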
Example 107
Source File: BaseWriter.scala    From kafka-connect-sap   with Apache License 2.0 5 votes vote down vote up
package com.sap.kafka.connect.sink

import java.sql.Connection
import java.util

import org.apache.kafka.connect.sink.SinkRecord
import org.slf4j.{Logger, LoggerFactory}


abstract class BaseWriter {

  private val log: Logger = LoggerFactory.getLogger(getClass)
  private var connection: Connection = null

  protected[sink] def initializeConnection(): Unit

  protected[sink] def write(records: util.Collection[SinkRecord]): Unit

  private[sink] def close(): Unit = {
    if (connection != null) {
      try {
        connection.close()
        connection = null
      }
      catch {
        case _: Exception => log.warn("Ignoring error closing connection")
      }
    }
  }
} 
Example 108
Source File: HANAWriter.scala    From kafka-connect-sap   with Apache License 2.0 5 votes vote down vote up
package com.sap.kafka.connect.sink.hana

import java.sql.Connection
import java.util

import com.google.common.base.Function
import com.google.common.collect.Multimaps
import com.sap.kafka.client.hana.HANAJdbcClient
import com.sap.kafka.connect.config.hana.HANAConfig
import com.sap.kafka.connect.sink.BaseWriter
import org.apache.kafka.connect.sink.SinkRecord
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.JavaConversions._


class HANAWriter(config: HANAConfig, hanaClient: HANAJdbcClient,
                 tableCache: scala.collection.mutable.Map[String, HANASinkRecordsCollector])
  extends BaseWriter {

  private val log: Logger = LoggerFactory.getLogger(getClass)
  private var connection: Connection = null

  override def initializeConnection(): Unit = {
    if (connection == null || connection.isClosed) {
      connection = hanaClient.getConnection
    } else if (!connection.isValid(120)) {
      connection.close()
      connection = hanaClient.getConnection
    }
    connection.setAutoCommit(false)
  }


  override def write(records: util.Collection[SinkRecord]): Unit = {
    log.info("write records to HANA")
    log.info("initialize connection to HANA")

    initializeConnection()

    val topicMap = Multimaps.index(records, new Function[SinkRecord, String] {
      override def apply(sinkRecord: SinkRecord) = sinkRecord.topic()
    }).asMap().toMap

    for ((topic, recordsPerTopic) <- topicMap) {
      var table = config.topicProperties(topic).get("table.name").get
      if (table.contains("${topic}")) {
        table = table.replace("${topic}", topic)
      }

      val recordsCollector: Option[HANASinkRecordsCollector] = tableCache.get(table)

      recordsCollector match {
        case None =>
          val tableRecordsCollector = new HANASinkRecordsCollector(table, hanaClient, connection, config)
          tableCache.put(table, tableRecordsCollector)
          tableRecordsCollector.add(recordsPerTopic.toSeq)
        case Some(tableRecordsCollector) =>
          if (config.autoSchemaUpdateOn) {
            tableRecordsCollector.tableConfigInitialized = false
          }
          tableRecordsCollector.add(recordsPerTopic.toSeq)
      }
    }
    flush(tableCache.toMap)
    log.info("flushing records to HANA successful")
  }

  private def flush(tableCache: Map[String, HANASinkRecordsCollector]): Unit = {
    log.info("flush records into HANA")
    for ((table, recordsCollector) <- tableCache) {
        recordsCollector.flush()
    }
    hanaClient.commit(connection)
  }

} 
Example 109
Source File: DB.scala    From recogito2   with Apache License 2.0 5 votes vote down vote up
package storage.db

import akka.actor.ActorSystem
import com.google.inject.AbstractModule
import java.sql.Connection
import javax.inject.{ Inject, Singleton }
import services.user.UserService
import services.user.Roles
import org.jooq.impl.DSL
import org.jooq.{ SQLDialect, DSLContext }
import play.api.Logger
import play.api.db.Database
import scala.collection.JavaConversions._
import scala.concurrent.{ ExecutionContext, Future }
import scala.io.Source

object DB {

  val CURRENT_SQLDIALECT = SQLDialect.POSTGRES_9_4

}


  // Note: the enclosing class declaration for the methods below was elided by the example extraction.
  private def initDB(connection: Connection) = {

    // Splitting by ; is not 100% robust - but should be sufficient for our own schema file
    val statement = connection.createStatement

    Source.fromFile("conf/schema.sql", "UTF-8")
      .getLines().map(_.trim)
      .filter(line => !(line.startsWith("--") || line.isEmpty))
      .mkString(" ").split(";")
      .foreach(s => {
        statement.addBatch(s + ";")
      })

    statement.executeBatch()
    statement.close()
  }
  
  private def createDefaultUserIfEmpty() =
    userService.countUsers.map { count =>
      if (count == 0) {
        Logger.warn("#######################################################")
        Logger.warn("# Empty user table - creating default recogito/recogito")
        Logger.warn("#######################################################")
        
        val f = for {
          _ <- userService.insertUser("recogito", "[email protected]", "recogito", false)
          _ <- userService.insertUserRole("recogito", Roles.Admin)
        } yield()
  
        f.map { _ =>
          Logger.warn("# Done. Make sure to remove this user in production!")
          Logger.warn("#######################################################")
        } recover { case t: Throwable => t.printStackTrace() }
      }
    } recover { case t: Throwable =>
      t.printStackTrace()
    }

} 
Example 110
Source File: JdbcSourceStage.scala    From fusion-data   with Apache License 2.0 5 votes vote down vote up
package mass.connector.sql

import java.sql.{ Connection, PreparedStatement, ResultSet }

import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler }
import akka.stream.{ Attributes, Outlet, SourceShape }
import javax.sql.DataSource
import fusion.jdbc.ConnectionPreparedStatementCreator
import fusion.jdbc.util.JdbcUtils

import scala.util.control.NonFatal

class JdbcSourceStage(dataSource: DataSource, creator: ConnectionPreparedStatementCreator, fetchRowSize: Int)
    extends GraphStage[SourceShape[ResultSet]] {
  private val out: Outlet[ResultSet] = Outlet("JdbcSource.out")

  override def shape: SourceShape[ResultSet] = SourceShape(out)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) with OutHandler {
      var maybeConn =
        Option.empty[(Connection, Boolean, PreparedStatement, ResultSet)]

      setHandler(out, this)

      override def onPull(): Unit =
        maybeConn match {
          case Some((_, _, _, rs)) if rs.next() =>
            push(out, rs)
          case Some(_) =>
            completeStage()
          case None =>
            () // do nothing; wait for preStart() to complete
        }

      override def preStart(): Unit =
        try {
          val conn = dataSource.getConnection
          val autoCommit = conn.getAutoCommit
          conn.setAutoCommit(false)
          val stmt = creator(conn)
          val rs = stmt.executeQuery()
          //          rs.setFetchDirection(ResultSet.TYPE_FORWARD_ONLY)
          rs.setFetchSize(fetchRowSize)
          maybeConn = Option((conn, autoCommit, stmt, rs))
        } catch {
          case NonFatal(e) => failStage(e)
        }

      override def postStop(): Unit =
        for {
          (conn, autoCommit, stmt, rs) <- maybeConn
        } {
          JdbcUtils.closeResultSet(rs)
          JdbcUtils.closeStatement(stmt)
          conn.setAutoCommit(autoCommit)
          JdbcUtils.closeConnection(conn)
        }
    }
} 
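For context, a stage like this is typically wrapped in an Akka Streams Source. The sketch below is an illustration only: it assumes Akka 2.6 (so the ActorSystem provides the materializer), a configured dataSource, and that ConnectionPreparedStatementCreator accepts a plain Connection => PreparedStatement function.

import akka.actor.ActorSystem
import akka.stream.scaladsl.{Sink, Source}

implicit val system: ActorSystem = ActorSystem("jdbc-source-demo")

// Hedged sketch: the lambda is assumed to satisfy ConnectionPreparedStatementCreator.
val rows = Source.fromGraph(
  new JdbcSourceStage(
    dataSource,
    conn => conn.prepareStatement("SELECT id, name FROM users"),
    fetchRowSize = 100))

// Each element is the live ResultSet positioned on the current row, so read
// the values before the next pull.
rows
  .map(rs => rs.getLong("id") -> rs.getString("name"))
  .runWith(Sink.foreach(println))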
Example 111
Source File: ConnectionUtils.scala    From azure-sqldb-spark   with MIT License 5 votes vote down vote up
package com.microsoft.azure.sqldb.spark.connect

import java.sql.{Connection, DriverManager, SQLException}
import java.util.Properties

import com.microsoft.azure.sqldb.spark.config.{Config, SqlDBConfig}

/**
  * Helper and utility methods used for setting up or using a connection
  */
private[spark] object ConnectionUtils {

  /**
    * Retrieves all connection properties in the Config object
    * and returns them as a [[Properties]] object.
    *
    * @param config the Config object with specified connection properties.
    * @return A connection [[Properties]] object.
    */
  def createConnectionProperties(config: Config): Properties = {
    val connectionProperties = new Properties()
    for (key <- config.getAllKeys) {
      connectionProperties.put(key.toString, config.get[String](key.toString).get)
    }
    connectionProperties
  }

  /**
    * Adds the "jdbc:sqlserver://" prefix to a general server url
    *
    * @param url the string url without the JDBC prefix
    * @return the url with the added JDBC prefix
    */
  def createJDBCUrl(url: String): String = SqlDBConfig.JDBCUrlPrefix + url

  /**
    * Gets a JDBC connection based on Config properties
    *
    * @param config any read or write Config
    * @return a JDBC Connection
    */
  def getConnection(config: Config): Connection = {
    Class.forName(SqlDBConfig.SQLjdbcDriver)
    DriverManager.getConnection(
      createJDBCUrl(config.get[String](SqlDBConfig.URL).get), createConnectionProperties(config))
  }

  /**
    * Retrieves the DBTable or QueryCustom specified in the config.
    * NOTE: only one property can exist within config.
    *
    * @param config the Config object with specified properties.
    * @return The specified DBTable or QueryCustom
    */
  def getTableOrQuery(config: Config): String = {
    config.get[String](SqlDBConfig.DBTable).getOrElse(
      getQueryCustom(config.get[String](SqlDBConfig.QueryCustom).get)
    )
  }

  /**
    * The JDBC driver requires parentheses and a temp variable around any custom queries.
    * This adds the required syntax so users only need to specify the query.
    *
    * @param query the default query
    * @return the syntactically correct query to be executed by the JDBC driver.
    */
  def getQueryCustom(query: String): String = s"($query) QueryCustom"

} 
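As a rough usage sketch (the connection settings are placeholders; the Config(Map(...)) builder follows the library's documented style):

import com.microsoft.azure.sqldb.spark.config.Config

// Hedged sketch: placeholder server, database and credentials.
val config = Config(Map(
  "url"          -> "myserver.database.windows.net",
  "databaseName" -> "mydb",
  "dbTable"      -> "dbo.Clients",
  "user"         -> "admin",
  "password"     -> "secret"))

val conn = ConnectionUtils.getConnection(config)
try {
  val rs = conn.createStatement().executeQuery(
    s"SELECT COUNT(*) FROM ${ConnectionUtils.getTableOrQuery(config)}")
  if (rs.next()) println(rs.getInt(1))
} finally {
  conn.close()
}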
Example 112
Source File: QueryFunctions.scala    From azure-sqldb-spark   with MIT License 5 votes vote down vote up
package com.microsoft.azure.sqldb.spark.query

import java.sql.{Connection, SQLException}

import com.microsoft.azure.sqldb.spark.connect.ConnectionUtils._
import com.microsoft.azure.sqldb.spark.LoggingTrait
import com.microsoft.azure.sqldb.spark.config.{Config, SqlDBConfig}
import com.microsoft.azure.sqldb.spark.connect._
import org.apache.spark.sql.{DataFrame, SQLContext}


  def sqlDBQuery(config: Config): Either[DataFrame, Boolean] = {

    var connection: Connection = null

    val sql = config.get[String](SqlDBConfig.QueryCustom).getOrElse(
      throw new IllegalArgumentException("Query not found in QueryCustom in Config")
    )

    try {
      connection = getConnection(config)
      val statement = connection.createStatement()

      if (statement.execute(sql)) {
        Left(sqlContext.read.sqlDB(config))
      }
      else {
        Right(true)
      }
    }
    catch {
      case sqlException: SQLException => {
        sqlException.printStackTrace()
        Right(false)
      }
      case exception: Exception => {
        exception.printStackTrace()
        Right(false)
      }
    }
    finally {
      if (connection != null) {
        connection.close()
      }
    }
  }
} 
Example 113
Source File: MySQLUtil.scala    From SqlShift   with MIT License 5 votes vote down vote up
package com.goibibo.sqlshift

import java.net.URL
import java.sql.{Connection, DriverManager}
import java.util.Properties

import com.typesafe.config.Config
import org.slf4j.{Logger, LoggerFactory}

import scala.io.Source


object MySQLUtil {
    private val logger: Logger = LoggerFactory.getLogger(this.getClass)

    private def getMySQLConnection(config: Config): Connection = {
        val mysql = config.getConfig("mysql")
        val connectionProps = new Properties()
        connectionProps.put("user", mysql.getString("username"))
        connectionProps.put("password", mysql.getString("password"))
        val jdbcUrl = s"jdbc:mysql://${mysql.getString("hostname")}:${mysql.getInt("portno")}/${mysql.getString("db")}?createDatabaseIfNotExist=true&useSSL=false"
        Class.forName("com.mysql.jdbc.Driver")
        DriverManager.getConnection(jdbcUrl, connectionProps)
    }

    def createTableAndInsertRecords(config: Config, tableName: String, psvFile: URL): Unit = {
        logger.info("Inserting records in table: {}", tableName)
        val records = Source.fromFile(psvFile.toURI).getLines().toList.drop(1) // removing header

        val conn = getMySQLConnection(config)
        val statement = conn.createStatement()
        try {
            val tableCreateQuery = config.getString("table.tableCreateQuery").replace("${tableName}", tableName)
            logger.info("Running query: {}", tableCreateQuery)
            statement.executeUpdate(tableCreateQuery)
            val insertIntoQuery = config.getString("table.insertIntoQuery").replace("${tableName}", tableName)
            logger.info("Running query: {}", insertIntoQuery)
            records.foreach { record: String =>
                val columns = record.split("\\|")
                val query = insertIntoQuery.format(columns: _*)
                statement.executeUpdate(query)
            }
        } finally {
            statement.close()
            conn.close()
        }
    }
} 
Example 114
Source File: SparkNRedshiftUtil.scala    From SqlShift   with MIT License 5 votes vote down vote up
package com.goibibo.sqlshift

import java.sql.{Connection, DriverManager}
import java.util.Properties

import com.databricks.spark.redshift.RedshiftReaderM
import com.typesafe.config.Config
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, Suite}
import org.slf4j.{Logger, LoggerFactory}


trait SparkNRedshiftUtil extends BeforeAndAfterAll {
    self: Suite =>
    private val logger: Logger = LoggerFactory.getLogger(this.getClass)
    @transient private var _sc: SparkContext = _
    @transient private var _sqlContext: SQLContext = _

    def sc: SparkContext = _sc
    def sqlContext: SQLContext = _sqlContext

    private def getRedshiftConnection(config: Config): Connection = {
        val redshift = config.getConfig("redshift")
        val connectionProps = new Properties()
        connectionProps.put("user", redshift.getString("username"))
        connectionProps.put("password", redshift.getString("password"))
        val jdbcUrl = s"jdbc:redshift://${redshift.getString("hostname")}:${redshift.getInt("portno")}/${redshift.getString("database")}?useSSL=false"
        Class.forName("com.amazon.redshift.jdbc4.Driver")
        DriverManager.getConnection(jdbcUrl, connectionProps)
    }

    val getSparkContext: (SparkContext, SQLContext) = {
        val sparkConf: SparkConf = new SparkConf().setAppName("Full Dump Testing").setMaster("local")
        val sc: SparkContext = new SparkContext(sparkConf)
        val sqlContext: SQLContext = new SQLContext(sc)

        System.setProperty("com.amazonaws.services.s3.enableV4", "true")
        sc.hadoopConfiguration.set("fs.s3a.endpoint", "s3.ap-south-1.amazonaws.com")
        sc.hadoopConfiguration.set("fs.s3a.fast.upload", "true")
        (sc, sqlContext)
    }

    def readTableFromRedshift(config: Config, tableName: String): DataFrame = {
        val redshift: Config = config.getConfig("redshift")
        val options = Map("dbtable" -> tableName,
            "user" -> redshift.getString("username"),
            "password" -> redshift.getString("password"),
            "url" -> s"jdbc:redshift://${redshift.getString("hostname")}:${redshift.getInt("portno")}/${redshift.getString("database")}",
            "tempdir" -> config.getString("s3.location"),
            "aws_iam_role" -> config.getString("redshift.iamRole")
        )
        RedshiftReaderM.getDataFrameForConfig(options, sc, sqlContext)
    }

    def dropTableRedshift(config: Config, tables: String*): Unit = {
        logger.info("Droping table: {}", tables)
        val conn = getRedshiftConnection(config)
        val statement = conn.createStatement()
        try {
            val dropTableQuery = s"""DROP TABLE ${tables.mkString(",")}"""
            logger.info("Running query: {}", dropTableQuery)
            statement.executeUpdate(dropTableQuery)
        } finally {
            statement.close()
            conn.close()
        }
    }

    override protected def beforeAll(): Unit = {
        super.beforeAll()
        val (sc, sqlContext) = getSparkContext
        _sc = sc
        _sqlContext = sqlContext
    }

    override protected def afterAll(): Unit = {
        super.afterAll()
        _sc.stop()
    }
} 
Example 115
Source File: PostgresTaskLogsDao.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres

import java.sql.{Connection, ResultSet}
import java.util.{Date, UUID}

import dao.TaskLogsDao
import model.TaskEventLog

import util.JdbcUtil._

class PostgresTaskLogsDao(implicit conn: Connection) extends TaskLogsDao {

  final val TABLE = "task_log"
  final val COL_ID = "task_log_id"
  final val COL_TASK_ID = "task_id"
  final val COL_WHEN = "when_" // 'when' is a reserved word in PostgreSQL
  final val COL_SOURCE = "source"
  final val COL_MESSAGE = "message"

  private def unmarshal(rs: ResultSet): TaskEventLog = {
    TaskEventLog(
      id = rs.getObject(COL_ID).asInstanceOf[UUID],
      taskId = rs.getObject(COL_TASK_ID).asInstanceOf[UUID],
      when = new Date(rs.getTimestamp(COL_WHEN).getTime()),
      source = rs.getString(COL_SOURCE),
      message = rs.getString(COL_MESSAGE)
    )
  }

  override def loadEventsForTask(taskId: UUID) = {
    val stmt =
      conn.prepareStatement(s"SELECT * FROM $TABLE WHERE $COL_TASK_ID = ?")
    stmt.setObject(1, taskId)
    stmt.executeQuery().map(unmarshal).toList
  }

  override def saveEvents(events: Seq[TaskEventLog]) {
    val sql =
      s"""
         |INSERT INTO $TABLE
         |($COL_ID, $COL_TASK_ID, $COL_WHEN, $COL_SOURCE, $COL_MESSAGE)
         |VALUES
         |(?, ?, ?, ?, ?)
       """.stripMargin
    val stmt = conn.prepareStatement(sql)
    events.foreach { event =>
      stmt.setObject(1, event.id)
      stmt.setObject(2, event.taskId)
      stmt.setTimestamp(3, new java.sql.Timestamp(event.when.getTime))
      stmt.setString(4, event.source)
      stmt.setString(5, event.message)
      stmt.addBatch()
    }
    stmt.executeBatch()
  }

} 
Example 116
Source File: PStatementTest.scala    From yoda-orm   with MIT License 5 votes vote down vote up
package in.norbor.yoda.orm

import java.sql.{Connection, DriverManager, ResultSet, Timestamp}

import com.typesafe.scalalogging.LazyLogging
import in.norbor.yoda.implicits.JavaSqlImprovement._
import mocks.People
import org.joda.time.DateTime
import org.scalatest.funsuite.AnyFunSuite


class PStatementTest extends AnyFunSuite {

  Class.forName("org.h2.Driver")

  private implicit val conn: Connection = DriverManager.getConnection("jdbc:h2:~/test", "sa", "")

  test("0) apply") {

    val ps = PStatement("SELECT 1")(conn)
    assert(ps !== null)

    ps.equals(null)
    ps.canEqual(null)
    ps.hashCode
    ps.toString
    ps.productPrefix
    ps.productArity
    ps.productElement(0)
    ps.productIterator
    ps.copy()
  }

  test("0) query") {

    PStatement("DROP TABLE IF EXISTS yoda_sql; CREATE TABLE yoda_sql (id INTEGER);")
      .update
  }

  test("0) update") {

    val rs = PStatement("""select 1""")
      .query

    assert(rs !== null)
  }

  test("0) queryOne with non index parameter") {

    val result = PStatement("""select ?, ?, ?, ?, ?, ?, ?, ?""")
      .setBoolean(true)
      .setInt(1)
      .setLong(1L)
      .setDouble(1)
      .setString("YO")
      .setDateTime(DateTime.now)
      .setTimestamp(new Timestamp(System.currentTimeMillis))
      .setTimestamp(null)
      .queryOne(parse)

    assert(result.head._1 === true)
  }

  test("3) queryList with parse method") {

    val peoples = PStatement("""select 1 as id, 'Peerapat' as name, now() as born;""")
      .queryList(parsePeople)

    assert(peoples.head.id === 1)
    assert(peoples.head.name === "Peerapat")
    assert(peoples.head.born.getMillis <= DateTime.now.getMillis)
  }

  test("5) batch") {

    val insert = PStatement("INSERT INTO yoda_sql VALUES(?)")
      .setInt(1)
      .addBatch()
      .setInt(2)
      .addBatch()
      .executeBatch

    assert(insert.length === 2)
  }


  private def parse(rs: ResultSet): (Boolean, Int, Long, Double, String, DateTime, Timestamp) = (rs.getBoolean(1)
    , rs.getInt(2)
    , rs.getLong(3)
    , rs.getDouble(4)
    , rs.getString(5)
    , rs.getDateTime(6)
    , rs.getTimestamp(7)
  )

  private def parsePeople(rs: ResultSet): People = People(id = rs.getLong("id")
    , name = rs.getString("name")
    , born = rs.getDateTime("born")
  )

} 
Example 117
Source File: CustomerTimerDemo.scala    From flink-rookie   with Apache License 2.0 5 votes vote down vote up
package com.venn.stream.api.timer

import java.io.File
import java.sql.{Connection, DriverManager, PreparedStatement, SQLException}
import java.util
import java.util.{Timer, TimerTask}
import org.apache.flink.api.scala._
import com.venn.common.Common
import com.venn.util.TwoStringSource
import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.serialization.SimpleStringSchema
import org.apache.flink.configuration.Configuration
import org.apache.flink.runtime.state.filesystem.FsStateBackend
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.api.{CheckpointingMode, TimeCharacteristic}
import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer
import org.slf4j.LoggerFactory


      def query() = {
        logger.info("query mysql")
        try {
          Class.forName(driverName)
          conn = DriverManager.getConnection(jdbcUrl, username, password)
          ps = conn.prepareStatement("select id,name from venn.timer")
          val rs = ps.executeQuery

          while (!rs.isClosed && rs.next) {
            val id = rs.getString(1)
            val name = rs.getString(2)
            map.put(id, name)
          }
          logger.info("get config from db size : {}", map.size())

        } catch {
          case e@(_: ClassNotFoundException | _: SQLException) =>
            e.printStackTrace()
        } finally {
          if (conn != null) {
            conn.close()
          }
        }
      }
    })
//              .print()


    val sink = new FlinkKafkaProducer[String]("timer_out"
      , new SimpleStringSchema()
      , Common.getProp)
    stream.addSink(sink)
    env.execute(this.getClass.getName)

  }

} 
Example 118
Source File: MysqlSink1.scala    From flink-rookie   with Apache License 2.0 5 votes vote down vote up
package com.venn.stream.api.jdbcOutput

import java.sql.{Connection, DriverManager, PreparedStatement, SQLException}
import org.apache.flink.api.common.io.OutputFormat
import org.apache.flink.configuration.Configuration
import org.slf4j.{Logger, LoggerFactory}

class MysqlSink1 extends OutputFormat[User]{

  val logger: Logger = LoggerFactory.getLogger("MysqlSink1")
  var conn: Connection = _
  var ps: PreparedStatement = _
  val jdbcUrl = "jdbc:mysql://192.168.229.128:3306?useSSL=false&allowPublicKeyRetrieval=true"
  val username = "root"
  val password = "123456"
  val driverName = "com.mysql.jdbc.Driver"

  override def configure(parameters: Configuration): Unit = {
    // not need
  }

  override def open(taskNumber: Int, numTasks: Int): Unit = {
    try {
      Class.forName(driverName)
      conn = DriverManager.getConnection(jdbcUrl, username, password)

      // close auto commit
      conn.setAutoCommit(false)
    } catch {
      case e@(_: ClassNotFoundException | _: SQLException) =>
        logger.error("init mysql error")
        e.printStackTrace()
        System.exit(-1)
    }
  }

  override def writeRecord(user: User): Unit = {

    println("get user : " + user.toString)
    ps = conn.prepareStatement("insert into async.user(username, password, sex, phone) values(?,?,?,?)")
    ps.setString(1, user.username)
    ps.setString(2, user.password)
    ps.setInt(3, user.sex)
    ps.setString(4, user.phone)

    ps.execute()
    conn.commit()
  }

  override def close(): Unit = {

    if (conn != null){
      conn.commit()
      conn.close()
    }
  }
} 
Example 119
Source File: MysqlSink.scala    From flink-rookie   with Apache License 2.0 5 votes vote down vote up
package com.venn.stream.api.jdbcOutput

import java.sql.{Connection, DriverManager, PreparedStatement, SQLException}
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.sink.{RichSinkFunction, SinkFunction}
import org.slf4j.{Logger, LoggerFactory}

class MysqlSink extends RichSinkFunction[User] {

  val logger: Logger = LoggerFactory.getLogger("MysqlSink")
  var conn: Connection = _
  var ps: PreparedStatement = _
  val jdbcUrl = "jdbc:mysql://192.168.229.128:3306?useSSL=false&allowPublicKeyRetrieval=true"
  val username = "root"
  val password = "123456"
  val driverName = "com.mysql.jdbc.Driver"

  override def open(parameters: Configuration): Unit = {
    try {
      Class.forName(driverName)
      conn = DriverManager.getConnection(jdbcUrl, username, password)

      // close auto commit
      conn.setAutoCommit(false)
    } catch {
      case e@(_: ClassNotFoundException | _: SQLException) =>
        logger.error("init mysql error")
        e.printStackTrace()
        System.exit(-1)
    }
  }

  
  override def invoke(user: User, context: SinkFunction.Context[_]): Unit = {
    println("get user : " + user.toString)
    ps = conn.prepareStatement("insert into async.user(username, password, sex, phone) values(?,?,?,?)")
    ps.setString(1, user.username)
    ps.setString(2, user.password)
    ps.setInt(3, user.sex)
    ps.setString(4, user.phone)

    ps.execute()
    conn.commit()
  }



  override def close(): Unit = {
    if (conn != null){
      conn.commit()
      conn.close()
    }
  }
} 
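Attaching the sink to a Flink job is then a one-liner; in this sketch env (a StreamExecutionEnvironment) and userStream (a DataStream[User]) are assumed to exist.

// Hedged sketch: wiring the RichSinkFunction into an existing job.
userStream.addSink(new MysqlSink)
env.execute("mysql-sink-demo")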
Example 120
Source File: JdbcUtil.scala    From sundial   with MIT License 5 votes vote down vote up
package util

import java.sql.{Connection, Timestamp, ResultSet}
import java.util.Date
import scala.language.implicitConversions

object JdbcUtil {

  implicit def resultSetItr(resultSet: ResultSet): Stream[ResultSet] = {
    new Iterator[ResultSet] {
      def hasNext = resultSet.next()
      def next() = resultSet
    }.toStream
  }

  implicit def javaDate(ts: Timestamp): Date = {
    new Date(ts.getTime())
  }

  implicit def dateToTimestamp(date: Date) = {
    if (date != null)
      new Timestamp(date.getTime())
    else
      null
  }

  private def getNullable[T](rs: ResultSet, f: ResultSet => T): Option[T] = {
    val obj = f(rs)
    if (rs.wasNull()) {
      Option.empty
    } else {
      Some(obj)
    }
  }

  def getIntOption(rs: ResultSet, col: String) =
    getNullable(rs, rs => rs.getInt(col))

  def makeStringArray(seq: Seq[String])(implicit conn: Connection) = {
    conn.createArrayOf("varchar", seq.toArray[AnyRef])
  }

  def getStringArray(rs: ResultSet, col: String) = {
    Option(rs.getArray(col))
      .map(_.getArray().asInstanceOf[Array[String]].toList)
  }

} 
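For illustration, the implicit ResultSet-to-Stream conversion above lets query results be mapped like a normal collection; conn, taskId, and the table name below are assumptions.

import util.JdbcUtil._

// Hedged sketch: materialise the messages of a task_log query into a List.
val stmt = conn.prepareStatement("SELECT message FROM task_log WHERE task_id = ?")
stmt.setObject(1, taskId)
val messages: List[String] =
  stmt.executeQuery()
    .map(rs => rs.getString("message"))
    .toList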
Example 121
Source File: PostgresShellCommandStateDao.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres

import java.sql.Connection
import java.util.UUID

import dao.ExecutableStateDao
import dao.postgres.common.ShellCommandStateTable
import dao.postgres.marshalling.PostgresBatchExecutorStatus
import model.ShellCommandState
import util.JdbcUtil._

class PostgresShellCommandStateDao(implicit conn: Connection)
    extends ExecutableStateDao[ShellCommandState] {

  override def loadState(taskId: UUID) = {
    import ShellCommandStateTable._
    val sql = s"SELECT * FROM $TABLE WHERE $COL_TASK_ID = ?"
    val stmt = conn.prepareStatement(sql)
    stmt.setObject(1, taskId)
    val rs = stmt.executeQuery()
    rs.map { row =>
        ShellCommandState(
          taskId = row.getObject(COL_TASK_ID).asInstanceOf[UUID],
          asOf = javaDate(row.getTimestamp(COL_AS_OF)),
          status = PostgresBatchExecutorStatus(rs.getString(COL_STATUS))
        )
      }
      .toList
      .headOption
  }

  override def saveState(state: ShellCommandState) = {
    import ShellCommandStateTable._
    val didUpdate = {
      val sql =
        s"""
           |UPDATE $TABLE
           |SET $COL_STATUS = ?::task_executor_status,
           |    $COL_AS_OF = ?
           |WHERE $COL_TASK_ID = ?
         """.stripMargin
      val stmt = conn.prepareStatement(sql)
      stmt.setString(1, PostgresBatchExecutorStatus(state.status))
      stmt.setTimestamp(2, state.asOf)
      stmt.setObject(3, state.taskId)
      stmt.executeUpdate() > 0
    }
    if (!didUpdate) {
      val sql =
        s"""
           |INSERT INTO $TABLE
           |($COL_TASK_ID, $COL_AS_OF, $COL_STATUS)
           |VALUES
           |(?, ?, ?::task_executor_status)
         """.stripMargin
      val stmt = conn.prepareStatement(sql)
      stmt.setObject(1, state.taskId)
      stmt.setTimestamp(2, state.asOf)
      stmt.setString(3, PostgresBatchExecutorStatus(state.status))
      stmt.execute()
    }
  }

} 
Example 122
Source File: PostgresBatchStateDao.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres

import java.sql.Connection
import java.util.UUID

import dao.ExecutableStateDao
import dao.postgres.marshalling.PostgresBatchExecutorStatus
import model.BatchContainerState
import util.JdbcUtil._

class PostgresBatchStateDao(implicit conn: Connection)
    extends ExecutableStateDao[BatchContainerState] {

  override def loadState(taskId: UUID) = {
    import dao.postgres.common.BatchStateTable._
    val sql = s"SELECT * FROM $TABLE WHERE $COL_TASK_ID = ?"
    val stmt = conn.prepareStatement(sql)
    stmt.setObject(1, taskId)
    val rs = stmt.executeQuery()
    rs.map { row =>
        BatchContainerState(
          taskId = row.getObject(COL_TASK_ID).asInstanceOf[UUID],
          asOf = javaDate(row.getTimestamp(COL_AS_OF)),
          status = PostgresBatchExecutorStatus(rs.getString(COL_STATUS)),
          jobName = rs.getString(COL_JOB_NAME),
          jobId = rs.getObject(COL_JOB_ID).asInstanceOf[UUID],
          logStreamName = Option(rs.getString(COL_LOGSTREAM_NAME))
        )
      }
      .toList
      .headOption
  }

  override def saveState(state: BatchContainerState) = {
    import dao.postgres.common.BatchStateTable._
    val didUpdate = {
      val sql =
        s"""
           |UPDATE $TABLE
           |SET
           |  $COL_STATUS = ?::batch_executor_status,
           |  $COL_AS_OF = ?,
           |  $COL_JOB_ID = ?,
           |  $COL_JOB_NAME = ?,
           |  $COL_LOGSTREAM_NAME = ?
           |WHERE $COL_TASK_ID = ?
         """.stripMargin
      val stmt = conn.prepareStatement(sql)
      stmt.setString(1, PostgresBatchExecutorStatus(state.status))
      stmt.setTimestamp(2, state.asOf)
      stmt.setObject(3, state.jobId)
      stmt.setString(4, state.jobName)
      stmt.setString(5, state.logStreamName.orNull)
      stmt.setObject(6, state.taskId)
      stmt.executeUpdate() > 0
    }
    if (!didUpdate) {
      val sql =
        s"""
           |INSERT INTO $TABLE
           |($COL_TASK_ID, $COL_AS_OF, $COL_STATUS, $COL_JOB_ID, $COL_JOB_NAME, $COL_LOGSTREAM_NAME)
           |VALUES
           |(?, ?, ?::batch_executor_status, ?, ?, ?)
         """.stripMargin
      val stmt = conn.prepareStatement(sql)
      stmt.setObject(1, state.taskId)
      stmt.setTimestamp(2, state.asOf)
      stmt.setString(3, PostgresBatchExecutorStatus(state.status))
      stmt.setObject(4, state.jobId)
      stmt.setString(5, state.jobName)
      stmt.setString(6, state.logStreamName.orNull)
      stmt.execute()
    }
  }

} 
Example 123
Source File: PostgresTaskMetadataDao.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres

import java.sql.Connection
import java.util.UUID

import dao.TaskMetadataDao
import dao.postgres.common.TaskMetadataTable
import model.TaskMetadataEntry
import util.JdbcUtil._

class PostgresTaskMetadataDao(implicit conn: Connection)
    extends TaskMetadataDao {

  override def loadMetadataForTask(taskId: UUID) = {
    import TaskMetadataTable._
    val sql = s"SELECT * FROM $TABLE WHERE $COL_TASK_ID = ?"
    val stmt = conn.prepareStatement(sql)
    stmt.setObject(1, taskId)
    stmt
      .executeQuery()
      .map { rs =>
        TaskMetadataEntry(
          id = rs.getObject(COL_ID).asInstanceOf[UUID],
          taskId = rs.getObject(COL_TASK_ID).asInstanceOf[UUID],
          when = javaDate(rs.getTimestamp(COL_WHEN)),
          key = rs.getString(COL_KEY),
          value = rs.getString(COL_VALUE)
        )
      }
      .toList
  }

  override def saveMetadataEntries(entries: Seq[TaskMetadataEntry]) = {
    import TaskMetadataTable._
    val sql =
      s"""
         |INSERT INTO $TABLE
         |($COL_ID, $COL_TASK_ID, $COL_WHEN, $COL_KEY, $COL_VALUE)
         |VALUES
         |(?, ?, ?, ?, ?)
       """.stripMargin
    val stmt = conn.prepareStatement(sql)
    entries.foreach { entry =>
      stmt.setObject(1, entry.id)
      stmt.setObject(2, entry.taskId)
      stmt.setTimestamp(3, entry.when)
      stmt.setString(4, entry.key)
      stmt.setString(5, entry.value)
      stmt.addBatch()
    }
    stmt.executeBatch()
  }

} 
Example 124
Source File: ProcessDefinitionMarshaller.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres.marshalling

import java.sql.{Connection, PreparedStatement, ResultSet}

import dao.postgres.common.ProcessDefinitionTable
import model.{
  EmailNotification,
  Notification,
  ProcessDefinition,
  ProcessOverlapAction
}
import util.JdbcUtil._

object ProcessDefinitionMarshaller {

  private val postgresJsonMarshaller = new PostgresJsonMarshaller

  def marshal(definition: ProcessDefinition,
              stmt: PreparedStatement,
              columns: Seq[String],
              startIndex: Int = 1)(implicit conn: Connection) = {
    import ProcessDefinitionTable._
    var index = startIndex
    columns.foreach { col =>
      col match {
        case COL_NAME => stmt.setString(index, definition.name)
        case COL_DESCRIPTION =>
          stmt.setString(index, definition.description.orNull)
        case COL_SCHEDULE =>
          stmt.setString(
            index,
            definition.schedule.map(PostgresJsonMarshaller.toJson).orNull)
        case COL_OVERLAP_ACTION =>
          stmt.setString(index, definition.overlapAction match {
            case ProcessOverlapAction.Wait      => OVERLAP_WAIT
            case ProcessOverlapAction.Terminate => OVERLAP_TERMINATE
          })
        case COL_TEAMS => stmt.setString(index, "[]")
        case COL_NOTIFICATIONS =>
          stmt.setString(
            index,
            postgresJsonMarshaller.toJson(definition.notifications))
        case COL_DISABLED   => stmt.setBoolean(index, definition.isPaused)
        case COL_CREATED_AT => stmt.setTimestamp(index, definition.createdAt)
      }
      index += 1
    }
  }

  def unmarshal(rs: ResultSet): ProcessDefinition = {
    import ProcessDefinitionTable._
    ProcessDefinition(
      name = rs.getString(COL_NAME),
      description = Option(rs.getString(COL_DESCRIPTION)),
      schedule = Option(rs.getString(COL_SCHEDULE))
        .map(PostgresJsonMarshaller.toSchedule),
      overlapAction = rs.getString(COL_OVERLAP_ACTION) match {
        case OVERLAP_WAIT      => ProcessOverlapAction.Wait
        case OVERLAP_TERMINATE => ProcessOverlapAction.Terminate
      },
      notifications = this.getNotifications(rs),
      isPaused = rs.getBoolean(COL_DISABLED),
      createdAt = javaDate(rs.getTimestamp(COL_CREATED_AT))
    )
  }

  private def getNotifications(rs: ResultSet): Seq[Notification] = {
    import ProcessDefinitionTable._
    val teams = PostgresJsonMarshaller
      .toTeams(rs.getString(COL_TEAMS))
      .map(team => EmailNotification(team.name, team.email, team.notifyAction))
    val notifications =
      postgresJsonMarshaller.toNotifications(rs.getString(COL_NOTIFICATIONS))
    notifications ++ teams
  }

} 
Example 125
Source File: ConnectionPool.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres.common

import java.sql.Connection

trait ConnectionPool {
  def fetchConnection(): Connection
  def withConnection[T](f: Connection => T) = {
    val connection = fetchConnection()
    connection.setAutoCommit(false)
    try {
      f(connection)
    } finally {
      connection.close()
    }
  }
} 
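A short sketch of how a DAO from this project might run through withConnection; pool (any ConnectionPool implementation) and someTaskId are assumptions.

// Hedged sketch: reads need no commit; writes would require an explicit
// conn.commit() before the block returns, since autocommit is disabled.
val events = pool.withConnection { implicit conn =>
  val dao = new PostgresTaskLogsDao
  dao.loadEventsForTask(someTaskId)
}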
Example 126
Source File: PostgresEmrStateDao.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres

import java.sql.Connection
import java.util.UUID

import dao.ExecutableStateDao
import dao.postgres.common.EmrStateTable
import dao.postgres.marshalling.PostgresEmrExecutorStatus
import model.EmrJobState
import util.JdbcUtil._

class PostgresEmrStateDao(implicit conn: Connection)
    extends ExecutableStateDao[EmrJobState] {

  override def loadState(taskId: UUID) = {
    import EmrStateTable._
    val sql = s"SELECT * FROM $TABLE WHERE $COL_TASK_ID = ?"
    val stmt = conn.prepareStatement(sql)
    stmt.setObject(1, taskId)
    val rs = stmt.executeQuery()
    rs.map { row =>
        EmrJobState(
          taskId = row.getObject(COL_TASK_ID).asInstanceOf[UUID],
          jobName = row.getString(COL_JOB_NAME),
          clusterId = row.getString(COL_CLUSTER_ID),
          stepIds = row.getString(COL_STEP_ID).split(","),
          region = row.getString(COL_REGION),
          asOf = javaDate(row.getTimestamp(COL_AS_OF)),
          status = PostgresEmrExecutorStatus(rs.getString(COL_STATUS))
        )
      }
      .toList
      .headOption
  }

  override def saveState(state: EmrJobState) = {
    import EmrStateTable._
    val didUpdate = {
      val sql =
        s"""
           |UPDATE $TABLE
           |SET
           |  $COL_STATUS = ?::emr_executor_status,
           |  $COL_AS_OF = ?
           |WHERE $COL_TASK_ID = ?
         """.stripMargin
      val stmt = conn.prepareStatement(sql)
      stmt.setString(1, PostgresEmrExecutorStatus(state.status))
      stmt.setTimestamp(2, state.asOf)
      stmt.setObject(3, state.taskId)
      stmt.executeUpdate() > 0
    }

    if (!didUpdate) {
      val sql =
        s"""
           |INSERT INTO $TABLE
           |($COL_TASK_ID, $COL_JOB_NAME, $COL_CLUSTER_ID, $COL_STEP_ID, $COL_REGION, $COL_AS_OF, $COL_STATUS)
           |VALUES
           |(?, ?, ?, ?, ?, ?, ?::emr_executor_status)
         """.stripMargin
      val stmt = conn.prepareStatement(sql)
      stmt.setObject(1, state.taskId)
      stmt.setString(2, state.jobName)
      stmt.setString(3, state.clusterId)
      stmt.setString(4, state.stepIds.mkString(","))
      stmt.setString(5, state.region)
      stmt.setTimestamp(6, state.asOf)
      stmt.setString(7, PostgresEmrExecutorStatus(state.status))
      stmt.execute()
    }
  }

} 
Example 127
Source File: TaskTriggerRequestMarshaller.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres.marshalling

import java.sql.{Connection, PreparedStatement, ResultSet}
import java.util.UUID
import dao.postgres.common.TaskTriggerRequestTable
import model.TaskTriggerRequest
import util.JdbcUtil._

object TaskTriggerRequestMarshaller {

  def marshal(request: TaskTriggerRequest,
              stmt: PreparedStatement,
              columns: Seq[String],
              startIndex: Int = 1)(implicit conn: Connection) = {
    import TaskTriggerRequestTable._
    var index = startIndex
    columns.foreach { col =>
      col match {
        case COL_REQUEST_ID => stmt.setObject(index, request.requestId)
        case COL_PROCESS_DEF_NAME =>
          stmt.setString(index, request.processDefinitionName)
        case COL_TASK_DEF_NAME =>
          stmt.setString(index, request.taskDefinitionName)
        case COL_REQUESTED_AT => stmt.setTimestamp(index, request.requestedAt)
        case COL_STARTED_PROCESS_ID =>
          stmt.setObject(index, request.startedProcessId.orNull)
      }
      index += 1
    }
  }

  def unmarshal(rs: ResultSet): TaskTriggerRequest = {
    import TaskTriggerRequestTable._
    TaskTriggerRequest(
      requestId = rs.getObject(COL_REQUEST_ID).asInstanceOf[UUID],
      processDefinitionName = rs.getString(COL_PROCESS_DEF_NAME),
      taskDefinitionName = rs.getString(COL_TASK_DEF_NAME),
      requestedAt = javaDate(rs.getTimestamp(COL_REQUESTED_AT)),
      startedProcessId =
        Option(rs.getObject(COL_STARTED_PROCESS_ID)).map(_.asInstanceOf[UUID])
    )
  }

} 
Example 128
Source File: TaskMarshaller.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres.marshalling

import java.sql.{Connection, PreparedStatement, ResultSet}
import java.util.UUID
import dao.postgres.common.TaskTable
import model.{Task, TaskStatus}
import util.JdbcUtil._

object TaskMarshaller {

  def unmarshalTask(rs: ResultSet): Task = {
    import TaskTable._
    Task(
      id = rs.getObject(COL_ID).asInstanceOf[UUID],
      processId = rs.getObject(COL_PROCESS_ID).asInstanceOf[UUID],
      processDefinitionName = rs.getString(COL_PROC_DEF_NAME),
      taskDefinitionName = rs.getString(COL_TASK_DEF_NAME),
      executable =
        PostgresJsonMarshaller.toExecutable(rs.getString(COL_EXECUTABLE)),
      previousAttempts = rs.getInt(COL_ATTEMPTS),
      startedAt = javaDate(rs.getTimestamp(COL_STARTED)),
      status = rs.getString(COL_STATUS) match {
        case STATUS_SUCCEEDED =>
          TaskStatus.Success(javaDate(rs.getTimestamp(COL_ENDED_AT)))
        case STATUS_FAILED =>
          TaskStatus.Failure(javaDate(rs.getTimestamp(COL_ENDED_AT)),
                             Option(rs.getString(COL_REASON)))
        case STATUS_RUNNING => TaskStatus.Running()
      }
    )
  }

  def marshalTask(task: Task,
                  stmt: PreparedStatement,
                  columns: Seq[String],
                  startIndex: Int = 1)(implicit conn: Connection) = {
    import TaskTable._
    var index = startIndex
    columns.foreach { col =>
      col match {
        case COL_ID         => stmt.setObject(index, task.id)
        case COL_PROCESS_ID => stmt.setObject(index, task.processId)
        case COL_PROC_DEF_NAME =>
          stmt.setString(index, task.processDefinitionName)
        case COL_TASK_DEF_NAME => stmt.setString(index, task.taskDefinitionName)
        case COL_EXECUTABLE =>
          stmt.setString(index, PostgresJsonMarshaller.toJson(task.executable))
        case COL_ATTEMPTS => stmt.setInt(index, task.previousAttempts)
        case COL_STARTED  => stmt.setTimestamp(index, task.startedAt)
        case COL_STATUS =>
          stmt.setString(index, task.status match {
            case TaskStatus.Success(_)    => STATUS_SUCCEEDED
            case TaskStatus.Failure(_, _) => STATUS_FAILED
            case TaskStatus.Running()     => STATUS_RUNNING
          })
        case COL_REASON =>
          stmt.setString(index, task.status match {
            case TaskStatus.Failure(_, reasons) => reasons.mkString(",")
            case _                              => null
          })
        case COL_ENDED_AT =>
          stmt.setTimestamp(index, task.endedAt.getOrElse(null))
      }
      index += 1
    }
  }

} 
Example 129
Source File: ProcessTriggerRequestMarshaller.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres.marshalling

import java.sql.{Connection, PreparedStatement, ResultSet}
import java.util.UUID

import dao.postgres.common.{ProcessTriggerRequestTable, TaskTriggerRequestTable}
import model.ProcessTriggerRequest
import util.JdbcUtil._

object ProcessTriggerRequestMarshaller {

  def marshal(request: ProcessTriggerRequest,
              stmt: PreparedStatement,
              columns: Seq[String],
              startIndex: Int = 1)(implicit conn: Connection) = {
    import ProcessTriggerRequestTable._
    var index = startIndex
    columns.foreach { col =>
      col match {
        case COL_REQUEST_ID => stmt.setObject(index, request.requestId)
        case COL_PROCESS_DEF_NAME =>
          stmt.setString(index, request.processDefinitionName)
        case COL_REQUESTED_AT => stmt.setTimestamp(index, request.requestedAt)
        case COL_STARTED_PROCESS_ID =>
          stmt.setObject(index, request.startedProcessId.orNull)
        case COL_TASK_FILTER =>
          stmt.setArray(index, request.taskFilter.map(makeStringArray).orNull)
      }
      index += 1
    }
  }

  def unmarshal(rs: ResultSet): ProcessTriggerRequest = {
    import ProcessTriggerRequestTable._
    ProcessTriggerRequest(
      requestId = rs.getObject(COL_REQUEST_ID).asInstanceOf[UUID],
      processDefinitionName = rs.getString(COL_PROCESS_DEF_NAME),
      requestedAt = javaDate(rs.getTimestamp(COL_REQUESTED_AT)),
      startedProcessId =
        Option(rs.getObject(COL_STARTED_PROCESS_ID)).map(_.asInstanceOf[UUID]),
      taskFilter = getStringArray(rs, COL_TASK_FILTER)
    )
  }

} 
Example 130
Source File: TaskDefinitionTemplateMarshaller.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres.marshalling

import java.sql.{Connection, PreparedStatement, ResultSet}

import dao.postgres.common.TaskDefinitionTemplateTable
import model._
import util.JdbcUtil._

object TaskDefinitionTemplateMarshaller {

  def marshal(definition: TaskDefinitionTemplate,
              stmt: PreparedStatement,
              columns: Seq[String],
              startIndex: Int = 1)(implicit conn: Connection) = {
    import TaskDefinitionTemplateTable._
    var index = startIndex
    columns.foreach { col =>
      col match {
        case COL_NAME => stmt.setString(index, definition.name)
        case COL_PROC_DEF_NAME =>
          stmt.setString(index, definition.processDefinitionName)
        case COL_EXECUTABLE =>
          stmt.setString(index,
                         PostgresJsonMarshaller.toJson(definition.executable))
        case COL_MAX_ATTEMPTS =>
          stmt.setInt(index, definition.limits.maxAttempts)
        case COL_MAX_EXECUTION_TIME =>
          stmt.setObject(index,
                         definition.limits.maxExecutionTimeSeconds.orNull)
        case COL_BACKOFF_SECONDS =>
          stmt.setInt(index, definition.backoff.seconds)
        case COL_BACKOFF_EXPONENT =>
          stmt.setDouble(index, definition.backoff.exponent)
        case COL_REQUIRED_DEPS =>
          stmt.setArray(index,
                        makeStringArray(definition.dependencies.required))
        case COL_OPTIONAL_DEPS =>
          stmt.setArray(index,
                        makeStringArray(definition.dependencies.optional))
        case COL_REQUIRE_EXPLICIT_SUCCESS =>
          stmt.setBoolean(index, definition.requireExplicitSuccess)
      }
      index += 1
    }
  }

  def unmarshal(rs: ResultSet): TaskDefinitionTemplate = {
    import TaskDefinitionTemplateTable._
    TaskDefinitionTemplate(
      name = rs.getString(COL_NAME),
      processDefinitionName = rs.getString(COL_PROC_DEF_NAME),
      executable =
        PostgresJsonMarshaller.toExecutable(rs.getString(COL_EXECUTABLE)),
      limits = TaskLimits(
        maxAttempts = rs.getInt(COL_MAX_ATTEMPTS),
        maxExecutionTimeSeconds = getIntOption(rs, COL_MAX_EXECUTION_TIME)
      ),
      backoff = TaskBackoff(
        seconds = rs.getInt(COL_BACKOFF_SECONDS),
        exponent = rs.getDouble(COL_BACKOFF_EXPONENT)
      ),
      dependencies = TaskDependencies(
        required = getStringArray(rs, COL_REQUIRED_DEPS).getOrElse(Seq.empty),
        optional = getStringArray(rs, COL_OPTIONAL_DEPS).getOrElse(Seq.empty)
      ),
      requireExplicitSuccess = rs.getBoolean(COL_REQUIRE_EXPLICIT_SUCCESS)
    )
  }

} 
Example 131
Source File: ProcessMarshaller.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres.marshalling

import java.sql.{Connection, PreparedStatement, ResultSet, Timestamp}
import java.util.UUID
import dao.postgres.common.ProcessTable
import model.{Process, ProcessStatus}
import util.JdbcUtil._

object ProcessMarshaller {

  def unmarshalProcess(rs: ResultSet): Process = {
    import ProcessTable._
    Process(
      id = rs.getObject(COL_ID).asInstanceOf[UUID],
      processDefinitionName = rs.getString(COL_DEF_NAME),
      startedAt = javaDate(rs.getTimestamp(COL_STARTED)),
      status = rs.getString(COL_STATUS) match {
        case STATUS_SUCCEEDED =>
          ProcessStatus.Succeeded(javaDate(rs.getTimestamp(COL_ENDED_AT)))
        case STATUS_FAILED =>
          ProcessStatus.Failed(javaDate(rs.getTimestamp(COL_ENDED_AT)))
        case STATUS_RUNNING => ProcessStatus.Running()
      },
      taskFilter = getStringArray(rs, COL_TASK_FILTER)
    )
  }

  def marshalProcess(process: Process,
                     stmt: PreparedStatement,
                     columns: Seq[String],
                     startIndex: Int = 1)(implicit conn: Connection) = {
    import ProcessTable._
    var index = startIndex
    columns.foreach { col =>
      col match {
        case COL_ID => stmt.setObject(index, process.id)
        case COL_DEF_NAME =>
          stmt.setString(index, process.processDefinitionName)
        case COL_STARTED =>
          stmt.setTimestamp(index, new Timestamp(process.startedAt.getTime()))
        case COL_ENDED_AT =>
          stmt.setTimestamp(index, process.endedAt.getOrElse(null))
        case COL_STATUS =>
          stmt.setString(
            index,
            process.status match {
              case ProcessStatus.Succeeded(_) => STATUS_SUCCEEDED
              case ProcessStatus.Failed(_)    => STATUS_FAILED
              case ProcessStatus.Running()    => STATUS_RUNNING
            }
          )
        case COL_TASK_FILTER =>
          stmt.setArray(index,
                        process.taskFilter.map(makeStringArray).getOrElse(null))
      }
      index += 1
    }
  }

} 
Example 132
Source File: TaskDefinitionMarshaller.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres.marshalling

import java.sql.{Connection, PreparedStatement, ResultSet}
import java.util.UUID
import dao.postgres.common.TaskDefinitionTable
import model.{TaskBackoff, TaskDefinition, TaskDependencies, TaskLimits}
import util.JdbcUtil._

object TaskDefinitionMarshaller {

  def marshal(definition: TaskDefinition,
              stmt: PreparedStatement,
              columns: Seq[String],
              startIndex: Int = 1)(implicit conn: Connection) = {
    import TaskDefinitionTable._
    var index = startIndex
    columns.foreach { col =>
      col match {
        case COL_NAME    => stmt.setString(index, definition.name)
        case COL_PROC_ID => stmt.setObject(index, definition.processId)
        case COL_EXECUTABLE =>
          stmt.setString(index,
                         PostgresJsonMarshaller.toJson(definition.executable))
        case COL_MAX_ATTEMPTS =>
          stmt.setInt(index, definition.limits.maxAttempts)
        case COL_MAX_EXECUTION_TIME =>
          stmt.setObject(index,
                         definition.limits.maxExecutionTimeSeconds.orNull)
        case COL_BACKOFF_SECONDS =>
          stmt.setInt(index, definition.backoff.seconds)
        case COL_BACKOFF_EXPONENT =>
          stmt.setDouble(index, definition.backoff.exponent)
        case COL_REQUIRED_DEPS =>
          stmt.setArray(index,
                        makeStringArray(definition.dependencies.required))
        case COL_OPTIONAL_DEPS =>
          stmt.setArray(index,
                        makeStringArray(definition.dependencies.optional))
        case COL_REQUIRE_EXPLICIT_SUCCESS =>
          stmt.setBoolean(index, definition.requireExplicitSuccess)
      }
      index += 1
    }
  }

  def unmarshal(rs: ResultSet): TaskDefinition = {
    import TaskDefinitionTable._
    TaskDefinition(
      name = rs.getString(COL_NAME),
      processId = rs.getObject(COL_PROC_ID).asInstanceOf[UUID],
      executable =
        PostgresJsonMarshaller.toExecutable(rs.getString(COL_EXECUTABLE)),
      limits = TaskLimits(
        maxAttempts = rs.getInt(COL_MAX_ATTEMPTS),
        maxExecutionTimeSeconds = getIntOption(rs, COL_MAX_EXECUTION_TIME)
      ),
      backoff = TaskBackoff(
        seconds = rs.getInt(COL_BACKOFF_SECONDS),
        exponent = rs.getDouble(COL_BACKOFF_EXPONENT)
      ),
      dependencies = TaskDependencies(
        required = getStringArray(rs, COL_REQUIRED_DEPS).getOrElse(Seq.empty),
        optional = getStringArray(rs, COL_OPTIONAL_DEPS).getOrElse(Seq.empty)
      ),
      requireExplicitSuccess = rs.getBoolean(COL_REQUIRE_EXPLICIT_SUCCESS)
    )
  }

} 
Example 133
Source File: H2Queries.scala    From daml   with Apache License 2.0 4 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.on.sql.queries

import java.sql.Connection

import anorm.SqlParser._
import anorm._
import com.daml.ledger.on.sql.Index
import com.daml.ledger.on.sql.queries.Queries._
import com.daml.ledger.participant.state.v1.LedgerId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}

import scala.util.Try

final class H2Queries(override protected implicit val connection: Connection)
    extends Queries
    with CommonQueries {
  override def updateOrRetrieveLedgerId(providedLedgerId: LedgerId): Try[LedgerId] = Try {
    SQL"MERGE INTO #$MetaTable USING DUAL ON table_key = $MetaTableKey WHEN NOT MATCHED THEN INSERT (table_key, ledger_id) VALUES ($MetaTableKey, $providedLedgerId)"
      .executeInsert()
    SQL"SELECT ledger_id FROM #$MetaTable WHERE table_key = $MetaTableKey"
      .as(str("ledger_id").single)
  }

  override def insertRecordIntoLog(key: Key, value: Value): Try[Index] =
    Try {
      SQL"INSERT INTO #$LogTable (entry_id, envelope) VALUES ($key, $value)"
        .executeInsert()
      ()
    }.flatMap(_ => lastInsertId())

  override protected val updateStateQuery: String =
    s"MERGE INTO $StateTable VALUES ({key}, {value})"

  private def lastInsertId(): Try[Index] = Try {
    SQL"CALL IDENTITY()"
      .as(long("IDENTITY()").single)
  }

  override final def truncate(): Try[Unit] = Try {
    SQL"truncate #$StateTable".executeUpdate()
    SQL"truncate #$LogTable".executeUpdate()
    SQL"truncate #$MetaTable".executeUpdate()
    ()
  }
}

object H2Queries {
  def apply(connection: Connection): Queries = {
    implicit val conn: Connection = connection
    new H2Queries
  }
}