java.sql.Statement Scala Examples

The following examples show how to use java.sql.Statement. You can go to the original project or source file by following the links above each example.
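Before the project examples, here is a minimal, self-contained sketch of the basic Statement life cycle in Scala: open a connection, create a statement, run updates and a query, then close everything in a finally block. The in-memory H2 URL, table, and column names below are illustrative placeholders, not taken from any of the projects that follow.

import java.sql.{Connection, DriverManager, ResultSet, Statement}

object StatementBasics extends App {
  // Placeholder in-memory H2 database; substitute your own JDBC URL and driver.
  val url = "jdbc:h2:mem:demo"
  var connection: Connection = null
  var statement: Statement = null
  try {
    connection = DriverManager.getConnection(url)
    statement = connection.createStatement()
    statement.executeUpdate("CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(64))")
    statement.executeUpdate("INSERT INTO users VALUES (1, 'alice')")
    val rs: ResultSet = statement.executeQuery("SELECT id, name FROM users")
    while (rs.next()) {
      println(rs.getInt("id") + " -> " + rs.getString("name"))
    }
  } finally {
    // Close in reverse order of creation, guarding against failed initialization.
    if (statement != null) statement.close()
    if (connection != null) connection.close()
  }
}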
Example 1
Source File: DNSstat.scala    From jdbcsink   with Apache License 2.0
import org.apache.spark.sql.SparkSession
import java.util.Properties
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.{from_json,window}
import java.sql.{Connection,Statement,DriverManager}
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.Row

class JDBCSink() extends ForeachWriter[Row] {
  val driver = "com.mysql.jdbc.Driver"
  var connection: Connection = _
  var statement: Statement = _

  def open(partitionId: Long, version: Long): Boolean = {
    Class.forName(driver)
    connection = DriverManager.getConnection("jdbc:mysql://10.88.1.102:3306/aptwebservice", "root", "mysqladmin")
    statement = connection.createStatement()
    true
  }

  def process(value: Row): Unit = {
    statement.executeUpdate("replace into DNSStat(ip,domain,time,count) values("
      + "'" + value.getString(0) + "'" + "," // ip
      + "'" + value.getString(1) + "'" + "," // domain
      + "'" + value.getTimestamp(2) + "'" + "," // time
      + value.getLong(3) // count
      + ")")
  }

  def close(errorOrNull: Throwable): Unit = {
    connection.close()
  }
}

object DNSstatJob{

val schema: StructType = StructType(
        Seq(StructField("Vendor", StringType,true),
         StructField("Id", IntegerType,true),
         StructField("Time", LongType,true),
         StructField("Conn", StructType(Seq(
                                        StructField("Proto", IntegerType, true), 
                                        StructField("Sport", IntegerType, true), 
                                        StructField("Dport", IntegerType, true), 
                                        StructField("Sip", StringType, true), 
                                        StructField("Dip", StringType, true)
                                        )), true),
        StructField("Dns", StructType(Seq(
                                        StructField("Domain", StringType, true), 
                                        StructField("IpCount", IntegerType, true), 
                                        StructField("Ip", StringType, true) 
                                        )), true)))

    def main(args: Array[String]) {
    val spark = SparkSession
          .builder
          .appName("DNSJob")
          .config("spark.some.config.option", "some-value")
          .getOrCreate()
    import spark.implicits._
    val connectionProperties = new Properties()
    connectionProperties.put("user", "root")
    connectionProperties.put("password", "mysqladmin")
    val bruteForceTab = spark.read
                .jdbc("jdbc:mysql://10.88.1.102:3306/aptwebservice", "DNSTab",connectionProperties)
    bruteForceTab.registerTempTable("DNSTab")
    val lines = spark
          .readStream
          .format("kafka")
          .option("kafka.bootstrap.servers", "10.94.1.110:9092")
          .option("subscribe","xdr")
          //.option("startingOffsets","earliest")
          .option("startingOffsets","latest")
          .load()
          .select(from_json($"value".cast(StringType),schema).as("jsonData"))
    lines.registerTempTable("xdr")
    val filterDNS = spark.sql("select CAST(from_unixtime(xdr.jsonData.Time DIV 1000000) as timestamp) as time,xdr.jsonData.Conn.Sip as sip, xdr.jsonData.Dns.Domain from xdr inner join DNSTab on xdr.jsonData.Dns.domain = DNSTab.domain")
    
    val windowedCounts = filterDNS
                        .withWatermark("time","5 minutes")
                        .groupBy(window($"time", "1 minutes", "1 minutes"),$"sip",$"domain")
                        .count()
                        .select($"sip",$"domain",$"window.start",$"count")

    val writer = new JDBCSink()
    val query = windowedCounts
      .writeStream
      .foreach(writer)
      .outputMode("update")
      .option("checkpointLocation", "/checkpoint/")
      .start()

    query.awaitTermination()
  }
}
Example 2
Source File: JDBCSink.scala    From BigData-News   with Apache License 2.0
package com.vita.spark

import java.sql.{Connection, ResultSet, SQLException, Statement}

import org.apache.log4j.{LogManager, Logger}
import org.apache.spark.sql.{ForeachWriter, Row}

/**
  * Writes Structured Streaming output rows to MySQL.
  */
class JDBCSink(url: String, username: String, password: String) extends ForeachWriter[Row] {

  var statement: Statement = _
  var resultSet: ResultSet = _
  var connection: Connection = _

  override def open(partitionId: Long, version: Long): Boolean = {
    connection = new MySqlPool(url, username, password).getJdbcConn()
    statement = connection.createStatement()
    print("open")
    true
  }

  override def process(value: Row): Unit = {
    println("process step one")
    val titleName = value.getAs[String]("titleName").replaceAll("[\\[\\]]", "")
    val count = value.getAs[Long]("count")

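    // Check-then-write "upsert": look the title up first, then update or insert.
    // Note that building SQL by string concatenation is open to SQL injection; a
    // PreparedStatement with bind parameters would be the safer choice.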
    val querySql = "select 1 from webCount where titleName = '" + titleName + "'"
    val insertSql = "insert into webCount(titleName,count) values('" + titleName + "' , '" + count + "')"
    val updateSql = "update webCount set count = " + count + " where titleName = '" + titleName + "'"
    println("process step two")
    try {
      // Check whether a row for this title already exists.
      val resultSet = statement.executeQuery(querySql)
      if (resultSet.next()) {
        println("updateSql")
        statement.executeUpdate(updateSql)
      } else {
        println("insertSql")
        statement.execute(insertSql)
      }
    } catch {
      // Order matters: more specific exception types must come before the general
      // ones, otherwise the later cases can never match.
      case ex: SQLException =>
        println("SQLException")
      case ex: RuntimeException =>
        println("RuntimeException")
      case ex: Exception =>
        println("Exception")
      case ex: Throwable =>
        println("Throwable")
    }
  }

  override def close(errorOrNull: Throwable): Unit = {
    // Only close resources that were actually initialized.
    if (statement != null) {
      statement.close()
    }
    if (connection != null) {
      connection.close()
    }
  }
} 
Example 3
Source File: PostgresInteropTest.scala    From spark-alchemy   with Apache License 2.0
package com.swoop.alchemy.spark.expressions.hll

import java.sql.{DriverManager, ResultSet, Statement}

import com.swoop.alchemy.spark.expressions.hll.functions._
import com.swoop.test_utils.SparkSessionSpec
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.scalatest.{Matchers, WordSpec}


case class Postgres(user: String, database: String, port: Int) {
  val con_str = s"jdbc:postgresql://localhost:$port/$database?user=$user"

  def execute[T](query: String, handler: ResultSet => T): T =
    execute(stm => handler(stm.executeQuery(query)))

  def update(query: String): Unit =
    execute(_.executeUpdate(query))

  def sparkRead(schema: String, table: String)(implicit spark: SparkSession): DataFrame =
    spark.read
      .format("jdbc")
      .option("url", s"jdbc:postgresql:${database}")
      .option("dbtable", s"${schema}.${table}")
      .option("user", user)
      .load()

  def sparkWrite(schema: String, table: String)(df: DataFrame): Unit =
    df.write
      .format("jdbc")
      .option("url", s"jdbc:postgresql:${database}")
      .option("dbtable", s"${schema}.${table}")
      .option("user", user)
      .save()

  private def execute[T](fn: Statement => T): T = {
    val conn = DriverManager.getConnection(con_str)
    try {
      val stm = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
      fn(stm)
    } finally {
      conn.close()
    }
  }
}


class PostgresInteropTest extends WordSpec with Matchers with SparkSessionSpec {

  import testImplicits._

  lazy val pg = Postgres("postgres", "postgres", 5432)

  "Postgres interop" should {
    "calculate same results" in {
      // use Aggregate Knowledge (Postgres-compatible) HLL implementation
      spark.conf.set(IMPLEMENTATION_CONFIG_KEY, "AGKN")

      // init Postgres extension for database
      pg.update("CREATE EXTENSION IF NOT EXISTS hll;")

      // create some random not-entirely distinct rows
      val rand = new scala.util.Random(42)
      val n = 100000
      val randomDF = sc.parallelize(
        Seq.fill(n) {
          (rand.nextInt(24), rand.nextInt(n))
        }
      ).toDF("hour", "id").cache

      // create hll aggregates (by hour)
      val byHourDF = randomDF.groupBy("hour").agg(hll_init_agg("id", .39).as("hll_id")).cache

      // send hlls to postgres
      pg.update("DROP TABLE IF EXISTS spark_hlls CASCADE;")
      pg.sparkWrite("public", "spark_hlls")(byHourDF)

      // convert hll column from `bytea` to `hll` type
      pg.update(
        """
          |ALTER TABLE spark_hlls
          |ALTER COLUMN hll_id TYPE hll USING CAST (hll_id AS hll);
          |""".stripMargin
      )

      // re-aggregate all hours in Spark
      val distinctSpark = byHourDF.select(hll_cardinality(hll_merge(byHourDF("hll_id")))).as[Long].first()
      // re-aggregate all hours in Postgres
      val distinctPostgres = pg.execute(
        "SELECT CAST (hll_cardinality(hll_union_agg(hll_id)) as Integer) AS approx FROM spark_hlls",
        (rs) => {
          rs.next()
          rs.getInt("approx")
        }
      )

      distinctSpark should be(distinctPostgres)
    }
  }

} 
Example 4
Source File: SqlAlertTriggerTest.scala    From pulse   with Apache License 2.0
package io.phdata.pulse.alertengine.trigger

import java.sql.{ DriverManager, Statement }

import io.phdata.pulse.alertengine.{ AlertsDb, TestObjectGenerator }
import io.phdata.pulse.solr.TestUtil
import org.scalatest.{ BeforeAndAfterAll, BeforeAndAfterEach, FunSuite }

class SqlAlertTriggerTest extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
  private val applicationName: String = "sql_test_" + TestUtil.randomIdentifier()
  private val dbUrl                   = s"jdbc:h2:mem:$applicationName;DB_CLOSE_DELAY=-1"

  override def beforeEach(): Unit = {
    super.beforeEach()
    AlertsDb.reset()
    prepareDatabase()
  }

  override def afterAll(): Unit =
    withStatement(statement => statement.execute("DROP ALL OBJECTS DELETE FILES;"))

  private def withStatement(function: Statement => Unit): Unit = {
    val connection = DriverManager.getConnection(dbUrl)
    try {
      val statement = connection.createStatement()
      try {
        function.apply(statement)
      } finally {
        statement.close()
      }
    } finally {
      connection.close()
    }
  }

  private def prepareDatabase(): Unit =
    withStatement { statement =>
      statement.execute("DROP ALL OBJECTS DELETE FILES;")
      statement.execute(s"""CREATE TABLE $applicationName (
           |id int not null,
           |error boolean not null,
           |message varchar(255) not null
           |);""".stripMargin)
    }

  test("query returns matching documents") {
    withStatement { statement =>
      statement.execute(s"""INSERT INTO $applicationName (id, error, message) VALUES
           |(1, true, 'sad'),
           |(3, true, 'very sad'),
           |(2, false, 'happy');""".stripMargin)
    }
    val alertRule =
      TestObjectGenerator.alertRule(
        query = s"""select * from $applicationName
           |where error = true
           |order by id""".stripMargin,
        retryInterval = 1,
        resultThreshold = Some(1),
        alertProfiles = List("[email protected]")
      )
    val expectedDocuments = Seq(
      Map("id" -> 1, "error" -> true, "message" -> "sad"),
      Map("id" -> 3, "error" -> true, "message" -> "very sad")
    )

    val trigger = new SqlAlertTrigger(dbUrl)
    val result  = trigger.query(applicationName, alertRule)
    assertResult(expectedDocuments)(result)
  }

  test("query returns no documents") {
    val alertRule = TestObjectGenerator.alertRule(query = s"select * from $applicationName")

    val trigger = new SqlAlertTrigger(dbUrl)
    assertResult(Seq.empty)(trigger.query(applicationName, alertRule))
  }

  test("invalid query") {
    val alertRule = TestObjectGenerator.alertRule()

    val trigger = new SqlAlertTrigger(dbUrl)
    assertThrows[Exception](trigger.query(applicationName, alertRule))
  }

  test("connection with options") {
    val alertRule = TestObjectGenerator.alertRule(query = s"select * from $applicationName")

    val trigger = new SqlAlertTrigger(dbUrl, dbOptions = Map("hello" -> "stuff"))
    trigger.query(applicationName, alertRule)
  }

  test("dbUrl null") {
    assertThrows[IllegalArgumentException](new SqlAlertTrigger(null))
  }

  test("dbUrl empty") {
    assertThrows[IllegalArgumentException](new SqlAlertTrigger(""))
  }

} 
Example 5
Source File: SapThriftJdbcTest.scala    From HANAVora-Extensions   with Apache License 2.0
package org.apache.spark.sql.hive.sap.thriftserver

import java.sql.{DriverManager, Statement}

import org.apache.hive.jdbc.HiveDriver
import org.scalatest.{BeforeAndAfterAll, FunSuite}


abstract class SapThriftJdbcTest(val thriftServer: SapThriftServer2Test) {

  def jdbcUri: String

  def withMultipleConnectionJdbcStatement(fs: (Statement => Unit)*): Unit = {
    val user = System.getProperty("user.name")
    val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") }
    val statements = connections.map(_.createStatement())

    try {
      statements.zip(fs).foreach { case (s, f) => f(s) }
    } finally {
      statements.foreach(_.close())
      connections.foreach(_.close())
    }
  }

  def withJdbcStatement(f: Statement => Unit): Unit = {
    withMultipleConnectionJdbcStatement(f)
  }

}

class SapThriftJdbcHiveDriverTest(override val thriftServer: SapThriftServer2Test)
  extends SapThriftJdbcTest(thriftServer) {
  Class.forName(classOf[HiveDriver].getCanonicalName)

  override def jdbcUri: String = if (thriftServer.mode == ServerMode.http) {
    s"""jdbc:hive2://${thriftServer.getServerAdressAndPort()}/
        |default?
        |hive.server2.transport.mode=http;
        |hive.server2.thrift.http.path=cliservice
     """.stripMargin.split("\n").mkString.trim
  } else {
    s"jdbc:hive2://${thriftServer.getServerAdressAndPort()}/"
  }

} 
Example 6
Source File: ThriftServerTest.scala    From Hive-JDBC-Proxy   with Apache License 2.0
package com.enjoyyin.hive.proxy.jdbc.test

import java.sql.{Connection, DriverManager, ResultSet, Statement}

import com.enjoyyin.hive.proxy.jdbc.util.Utils


private object ThriftServerTest extends App {
  val sql = """show tables"""
  val test_url = "jdbc:hive2://localhost:10001/default"
  Class.forName("org.apache.hive.jdbc.HiveDriver")
  def test(index: Int) = {
    var conn: Connection = null
    var stmt: Statement = null
    var rs: ResultSet = null
    Utils.tryFinally {
      conn = DriverManager.getConnection(test_url, "hduser0009", "")
      stmt = conn.createStatement
      rs = stmt.executeQuery(sql)
      while(rs.next) {
        println ("Date: " + Utils.dateFormat(System.currentTimeMillis) + ", " + index + ".tables => " + rs.getObject(1))
      }
      println("Date: " + Utils.dateFormat(System.currentTimeMillis) + ", ready to close " + index)
    } {
      if(rs != null) Utils.tryIgnoreError(rs.close())
      if(stmt != null) Utils.tryIgnoreError(stmt.close())
      if(conn != null) Utils.tryIgnoreError(conn.close())
    }
  }
  (0 until 8).foreach(i => new Thread {
    setName("thread-" + i)
    override def run(): Unit = {
      Utils.tryCatch(test(i)) { t =>
        println("Date: " + Utils.dateFormat(System.currentTimeMillis) + ", " + i + " has occur an error.")
        t.printStackTrace()
      }
    }
  }.start())
} 
Example 7
Source File: Database.scala    From schedoscope   with Apache License 2.0
package org.schedoscope.test

import java.sql.{Connection, ResultSet, Statement}

import org.schedoscope.dsl.{FieldLike, View}
import org.schedoscope.schema.ddl.HiveQl

import scala.collection.mutable.{HashMap, ListBuffer}

class Database(conn: Connection, url: String) {

  def selectForViewByQuery(v: View, query: String, orderByField: Option[FieldLike[_]]): List[Map[String, Any]] = {
    val res = ListBuffer[Map[String, Any]]()
    var statement: Statement = null
    var rs: ResultSet = null

    try {
      statement = conn.createStatement()
      rs = statement.executeQuery(query)

      while (rs.next()) {
        val row = HashMap[String, Any]()
        v.fields.view.zipWithIndex.foreach(f => {
          row.put(f._1.n, ViewSerDe.deserializeField(f._1.t, rs.getString(f._2 + 1)))
        })
        res.append(row.toMap)
      }
    }
    finally {
      if (rs != null) try {
        rs.close()
      } catch {
        case _: Throwable =>
      }

      if (statement != null) try {
        statement.close()
      } catch {
        case _: Throwable =>
      }
    }

    orderByField match {
      case Some(f) => res.sortBy {
        _(f.n) match {
          case null => ""
          case other => other.toString
        }
      }.toList
      case None => res.toList
    }
  }

  def selectView(v: View, orderByField: Option[FieldLike[_]]): List[Map[String, Any]] =
    selectForViewByQuery(v, HiveQl.selectAll(v), orderByField)

} 
Example 8
Source File: ExasolRDDSuite.scala    From spark-exasol-connector   with Apache License 2.0
package com.exasol.spark.rdd

import java.sql.Statement

import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StructType

import com.exasol.jdbc.EXAConnection
import com.exasol.jdbc.EXAResultSet
import com.exasol.spark.util.ExasolConnectionManager

import org.mockito.Mockito._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar

class ExasolRDDSuite extends AnyFunSuite with Matchers with MockitoSugar {

  test("`getPartitions` returns correct set of partitions") {
    val sparkContext = mock[SparkContext]
    val mainConnection = mock[EXAConnection]
    val mainStatement = mock[Statement]
    val mainResultSet = mock[EXAResultSet]
    val manager = mock[ExasolConnectionManager]

    val handle: Int = 7

    when(manager.mainConnection).thenReturn(mainConnection)
    when(manager.subConnections(mainConnection)).thenReturn(Seq("url1", "url2"))
    when(mainConnection.createStatement()).thenReturn(mainStatement)
    when(mainStatement.executeQuery("")).thenReturn(mainResultSet)
    when(mainResultSet.GetHandle()).thenReturn(handle)

    val rdd = new ExasolRDD(sparkContext, "", StructType(Nil), manager)
    val partitions = rdd.getPartitions

    assert(partitions.size == 2)
    partitions.zipWithIndex.foreach {
      case (part, idx) =>
        assert(part.index === idx)
        assert(part.isInstanceOf[ExasolRDDPartition])
        assert(part.asInstanceOf[ExasolRDDPartition].handle === handle)
        assert(part.asInstanceOf[ExasolRDDPartition].connectionUrl === s"url${idx + 1}")
    }
    verify(manager, times(1)).mainConnection
    verify(manager, times(1)).subConnections(mainConnection)
  }

  test("`getPartitions` throws exceptions if main connection is null") {
    val sparkContext = mock[SparkContext]
    val manager = mock[ExasolConnectionManager]

    when(manager.mainConnection).thenReturn(null)

    val thrown = intercept[RuntimeException] {
      new ExasolRDD(sparkContext, "", StructType(Nil), manager).getPartitions
    }
    assert(thrown.getMessage === "Could not establish main connection to Exasol!")

    verify(manager, times(1)).mainConnection
  }

} 
Example 9
Source File: ThriftServerBaseTest.scala    From incubator-livy   with Apache License 2.0
package org.apache.livy.thriftserver

import java.sql.{Connection, DriverManager, Statement}

import org.apache.hive.jdbc.HiveDriver
import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.livy.LivyConf
import org.apache.livy.LivyConf.{LIVY_SPARK_SCALA_VERSION, LIVY_SPARK_VERSION}
import org.apache.livy.server.AccessManager
import org.apache.livy.server.recovery.{SessionStore, StateStore}
import org.apache.livy.sessions.InteractiveSessionManager
import org.apache.livy.utils.LivySparkUtils.{formatSparkVersion, sparkScalaVersion, sparkSubmitVersion}

object ServerMode extends Enumeration {
  val binary, http = Value
}

abstract class ThriftServerBaseTest extends FunSuite with BeforeAndAfterAll {
  def mode: ServerMode.Value
  def port: Int

  val THRIFT_SERVER_STARTUP_TIMEOUT = 30000 // ms

  val livyConf = new LivyConf()
  val (sparkVersion, scalaVersionFromSparkSubmit) = sparkSubmitVersion(livyConf)
  val formattedSparkVersion: (Int, Int) = {
    formatSparkVersion(sparkVersion)
  }

  def jdbcUri(defaultDb: String, sessionConf: String*): String = if (mode == ServerMode.http) {
    s"jdbc:hive2://localhost:$port/$defaultDb?hive.server2.transport.mode=http;" +
      s"hive.server2.thrift.http.path=cliservice;${sessionConf.mkString(";")}"
  } else {
    s"jdbc:hive2://localhost:$port/$defaultDb?${sessionConf.mkString(";")}"
  }

  override def beforeAll(): Unit = {
    Class.forName(classOf[HiveDriver].getCanonicalName)
    livyConf.set(LivyConf.THRIFT_TRANSPORT_MODE, mode.toString)
    livyConf.set(LivyConf.THRIFT_SERVER_PORT, port)

    // Set formatted Spark and Scala version into livy configuration, this will be used by
    // session creation.
    livyConf.set(LIVY_SPARK_VERSION.key, formattedSparkVersion.productIterator.mkString("."))
    livyConf.set(LIVY_SPARK_SCALA_VERSION.key,
      sparkScalaVersion(formattedSparkVersion, scalaVersionFromSparkSubmit, livyConf))
    StateStore.init(livyConf)

    val ss = new SessionStore(livyConf)
    val sessionManager = new InteractiveSessionManager(livyConf, ss)
    val accessManager = new AccessManager(livyConf)
    LivyThriftServer.start(livyConf, sessionManager, ss, accessManager)
    LivyThriftServer.thriftServerThread.join(THRIFT_SERVER_STARTUP_TIMEOUT)
    assert(LivyThriftServer.getInstance.isDefined)
    assert(LivyThriftServer.getInstance.get.getServiceState == STATE.STARTED)
  }

  override def afterAll(): Unit = {
    LivyThriftServer.stopServer()
  }

  def withJdbcConnection(f: (Connection => Unit)): Unit = {
    withJdbcConnection("default", Seq.empty)(f)
  }

  def withJdbcConnection(db: String, sessionConf: Seq[String])(f: (Connection => Unit)): Unit = {
    withJdbcConnection(jdbcUri(db, sessionConf: _*))(f)
  }

  def withJdbcConnection(uri: String)(f: (Connection => Unit)): Unit = {
    val user = System.getProperty("user.name")
    val connection = DriverManager.getConnection(uri, user, "")
    try {
      f(connection)
    } finally {
      connection.close()
    }
  }

  def withJdbcStatement(f: (Statement => Unit)): Unit = {
    withJdbcConnection { connection =>
      val s = connection.createStatement()
      try {
        f(s)
      } finally {
        s.close()
      }
    }
  }
} 
Example 10
Source File: JdbcConnector.scala    From bandar-log   with Apache License 2.0
package com.aol.one.dwh.bandarlog.connectors

import java.sql.{Connection, Statement}

import com.aol.one.dwh.infra.sql.Setting
import com.aol.one.dwh.infra.sql.pool.HikariConnectionPool
import com.aol.one.dwh.infra.util.LogTrait
import com.aol.one.dwh.infra.sql.Query
import com.aol.one.dwh.infra.sql.pool.SqlSource.{PRESTO, VERTICA}
import com.facebook.presto.jdbc.PrestoConnection
import com.google.common.cache.CacheBuilder
import com.vertica.jdbc.VerticaConnection
import org.apache.commons.dbutils.ResultSetHandler
import resource.managed

import scala.concurrent.duration._
import scala.util.Try
import scalacache.guava.GuavaCache
import scalacache.memoization._
import scalacache.{CacheConfig, ScalaCache}


abstract class JdbcConnector(@cacheKeyExclude pool: HikariConnectionPool) extends LogTrait {

  implicit val scalaCache = ScalaCache(
    GuavaCache(CacheBuilder.newBuilder().maximumSize(100).build[String, Object]),
    cacheConfig = CacheConfig(keyPrefix = Some(pool.getName))
  )

  def runQuery[V](query: Query, @cacheKeyExclude handler: ResultSetHandler[V]): V = memoizeSync(50.seconds) {
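    // memoizeSync caches the result of this call for 50 seconds; the cache key is
    // built from the configured keyPrefix (the pool name) and the query argument,
    // while parameters annotated @cacheKeyExclude are left out of the key.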
    val rm =
      for {
        connection <- managed(pool.getConnection)
        statement  <- managed(connection.createStatement())
      } yield {
        applySettings(connection, statement, query.settings)
        logger.info(s"Running query:[${query.sql}] source:[${query.source}] settings:[${query.settings.mkString(",")}]")
        val resultSet = statement.executeQuery(query.sql)
        handler.handle(resultSet)
      }

    Try(rm.acquireAndGet(identity)).getOrElse(throw new RuntimeException(s"Failure:[$query]"))
  }

  private def applySettings(connection: Connection, statement: Statement, settings: Seq[Setting]) = {
    settings.foreach(setting => applySetting(connection, statement, setting))
  }

  def applySetting(connection: Connection, statement: Statement, setting: Setting)

}

object JdbcConnector {

  private class PrestoConnector(connectionPool: HikariConnectionPool) extends JdbcConnector(connectionPool) {
    override def applySetting(connection: Connection, statement: Statement, setting: Setting): Unit = {
      connection.unwrap(classOf[PrestoConnection]).setSessionProperty(setting.key, setting.value)
    }
  }

  private class VerticaConnector(connectionPool: HikariConnectionPool) extends JdbcConnector(connectionPool) {
    override def applySetting(connection: Connection, statement: Statement, setting: Setting): Unit = {
      connection.unwrap(classOf[VerticaConnection]).setProperty(setting.key, setting.value)
    }
  }

  def apply(connectorType: String, connectionPool: HikariConnectionPool): JdbcConnector = connectorType match {
    case VERTICA => new VerticaConnector(connectionPool)
    case PRESTO => new PrestoConnector(connectionPool)
    case _ => throw new IllegalArgumentException(s"Can't create connector for SQL source:[$connectorType]")
  }
} 
Example 11
Source File: JdbcConnectorTest.scala    From bandar-log   with Apache License 2.0
package com.aol.one.dwh.bandarlog.connectors

import java.sql.{Connection, DatabaseMetaData, ResultSet, Statement}

import com.aol.one.dwh.infra.config._
import com.aol.one.dwh.infra.sql.pool.HikariConnectionPool
import com.aol.one.dwh.infra.sql.{ListStringResultHandler, Setting, VerticaMaxValuesQuery}
import org.apache.commons.dbutils.ResultSetHandler
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

class JdbcConnectorTest extends FunSuite with MockitoSugar {

  private val statement = mock[Statement]
  private val resultSet = mock[ResultSet]
  private val connectionPool = mock[HikariConnectionPool]
  private val connection = mock[Connection]
  private val databaseMetaData = mock[DatabaseMetaData]
  private val resultSetHandler = mock[ResultSetHandler[Long]]
  private val listStringResultHandler = mock[ListStringResultHandler]

  test("check run query result for numeric batch_id column") {
    val resultValue = 100L
    val table = Table("table", List("column"), None)
    val query = VerticaMaxValuesQuery(table)
    when(connectionPool.getConnection).thenReturn(connection)
    when(connectionPool.getName).thenReturn("connection_pool_name")
    when(connection.createStatement()).thenReturn(statement)
    when(statement.executeQuery("SELECT MAX(column) AS column FROM table")).thenReturn(resultSet)
    when(connection.getMetaData).thenReturn(databaseMetaData)
    when(databaseMetaData.getURL).thenReturn("connection_url")
    when(resultSetHandler.handle(resultSet)).thenReturn(resultValue)

    val result = new DefaultJdbcConnector(connectionPool).runQuery(query, resultSetHandler)

    assert(result == resultValue)
  }

  test("check run query result for date/time partitions") {
    val resultValue = Some(20190924L)
    val table = Table("table", List("year", "month", "day"), Some(List("yyyy", "MM", "dd")))
    val query = VerticaMaxValuesQuery(table)
    when(connectionPool.getConnection).thenReturn(connection)
    when(connectionPool.getName).thenReturn("connection_pool_name")
    when(connection.createStatement()).thenReturn(statement)
    when(statement.executeQuery("SELECT DISTINCT year, month, day FROM table")).thenReturn(resultSet)
    when(connection.getMetaData).thenReturn(databaseMetaData)
    when(databaseMetaData.getURL).thenReturn("connection_url")
    when(listStringResultHandler.handle(resultSet)).thenReturn(resultValue)

    val result = new DefaultJdbcConnector(connectionPool).runQuery(query, listStringResultHandler)

    assert(result == resultValue)
  }
}

class DefaultJdbcConnector(connectionPool: HikariConnectionPool) extends JdbcConnector(connectionPool) {
  override def applySetting(connection: Connection, statement: Statement, setting: Setting): Unit = {}
} 
Example 12
Source File: H2Utils.scala    From morpheus   with Apache License 2.0
package org.opencypher.morpheus.testing.utils

import java.sql.{Connection, DriverManager, ResultSet, Statement}

import org.apache.spark.sql._
import org.opencypher.morpheus.api.io.sql.SqlDataSourceConfig

object H2Utils {

  implicit class ConnOps(conn: Connection) {
    def run[T](code: Statement => T): T = {
      val stmt = conn.createStatement()
      try { code(stmt) } finally { stmt.close() }
    }
    def execute(sql: String): Boolean = conn.run(_.execute(sql))
    def query(sql: String): ResultSet = conn.run(_.executeQuery(sql))
    def update(sql: String): Int = conn.run(_.executeUpdate(sql))
  }

  def withConnection[T](cfg: SqlDataSourceConfig.Jdbc)(code: Connection => T): T = {
    Class.forName(cfg.driver)
    val conn = (cfg.options.get("user"), cfg.options.get("password")) match {
      case (Some(user), Some(pass)) =>
        DriverManager.getConnection(cfg.url, user, pass)
      case _ =>
        DriverManager.getConnection(cfg.url)
    }
    try { code(conn) } finally { conn.close() }
  }

  implicit class DataFrameWriterOps(write: DataFrameWriter[Row]) {
    def maybeOption(key: String, value: Option[String]): DataFrameWriter[Row] =
      value.fold(write)(write.option(key, _))
  }

  implicit class DataFrameSqlOps(df: DataFrame) {

    def saveAsSqlTable(cfg: SqlDataSourceConfig.Jdbc, tableName: String): Unit =
      df.write
        .mode(SaveMode.Overwrite)
        .format("jdbc")
        .option("url", cfg.url)
        .option("driver", cfg.driver)
        .options(cfg.options)
        .option("dbtable", tableName)
        .save()
  }
} 
Example 13
Source File: HiveJDBCUtils.scala    From gimel   with Apache License 2.0
package com.paypal.gimel.hive.utilities

import java.security._
import java.sql.{Connection, DriverManager, Statement}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.security.UserGroupInformation

import com.paypal.gimel.common.conf.{GimelConstants, GimelProperties}

object HiveJDBCUtils {

  def apply(conf: GimelProperties, cluster: String): HiveJDBCUtils = {
    new HiveJDBCUtils(conf, cluster)
  }
}

class HiveJDBCUtils(val props: GimelProperties, cluster: String = "unknown_cluster") {
  val logger = com.paypal.gimel.logger.Logger()

  logger.info("Using Supplied KeyTab to authenticate KDC...")
  val conf = new Configuration
  conf.set(GimelConstants.SECURITY_AUTH, "kerberos")
  UserGroupInformation.setConfiguration(conf)
  val ugi: UserGroupInformation = UserGroupInformation.loginUserFromKeytabAndReturnUGI(props.principal, props.keytab)
  UserGroupInformation.setLoginUser(ugi)


  
  def withStatement(fn: Statement => Any): Any = {
    def MethodName: String = new Exception().getStackTrace.apply(1).getMethodName

    logger.info(" @Begin --> " + MethodName)
    withConnection {
      connection =>
        val statement = connection.createStatement
        var output: Any = None
        try {
          output = fn(statement)
        } catch {
          case e: Throwable =>
            e.printStackTrace
            throw e
        }
        finally {
          if (!statement.isClosed) {
            statement.close
          }
        }
        output
    }
  }

}