org.apache.spark.sql.jdbc.JdbcDialect Scala Examples

The following examples show how to use `org.apache.spark.sql.jdbc.JdbcDialect`. A line above each example names the source file it was taken from and the project it belongs to.
Example 1
Source file: MemsqlDialect.scala, from the memsql-spark-connector project (Apache License 2.0)
package com.memsql.spark

import java.sql.Types

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types._

/**
 * Spark [[JdbcDialect]] for MemSQL.
 *
 * Claims JDBC URLs beginning with `jdbc:memsql` and overrides the type mapping in both
 * directions:
 *  - [[getJDBCType]]: Spark type to MemSQL column type, used when writing.
 *  - [[getCatalystType]]: MemSQL column type to Spark type, used when reading.
 * It also quotes identifiers with backticks.
 */
case object MemsqlDialect extends JdbcDialect {

  /** Returns true for any JDBC URL that starts with `jdbc:memsql`. */
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:memsql")

  /** Largest DECIMAL scale [[getJDBCType]] will emit; larger scales are rejected. */
  val MEMSQL_DECIMAL_MAX_SCALE = 30

  /**
   * Chooses the MemSQL column type for a Spark type.
   *
   * Types without a MemSQL-specific entry fall back to `JdbcUtils.getCommonJDBCType`.
   *
   * @param dt Spark type of the column
   * @return the MemSQL type definition and its `java.sql.Types` code
   * @throws IllegalArgumentException for a [[DecimalType]] whose scale exceeds
   *                                  [[MEMSQL_DECIMAL_MAX_SCALE]], and for [[NullType]],
   *                                  for which no MemSQL column type is defined
   */
  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case BooleanType   => Some(JdbcType("BOOL", Types.BOOLEAN))
    case ByteType      => Some(JdbcType("TINYINT", Types.TINYINT))
    case ShortType     => Some(JdbcType("SMALLINT", Types.SMALLINT))
    case FloatType     => Some(JdbcType("FLOAT", Types.FLOAT))
    case TimestampType => Some(JdbcType("TIMESTAMP(6)", Types.TIMESTAMP))
    case decimal: DecimalType if decimal.scale <= MEMSQL_DECIMAL_MAX_SCALE =>
      Some(JdbcType(s"DECIMAL(${decimal.precision}, ${decimal.scale})", Types.DECIMAL))
    case decimal: DecimalType =>
      throw new IllegalArgumentException(
        s"Too big scale specified(${decimal.scale}). MemSQL DECIMAL maximum scale is ${MEMSQL_DECIMAL_MAX_SCALE}")
    case NullType =>
      throw new IllegalArgumentException(
        "No corresponding MemSQL type found for NullType. If you want to use NullType, please write to an already existing MemSQL table.")
    case other => JdbcUtils.getCommonJDBCType(other)
  }

  /**
   * Maps a MemSQL column to a Spark type.
   *
   * Recognised (sqlType, typeName) pairs:
   *  - REAL/"FLOAT" maps to FloatType.
   *  - BIT/"BIT" maps to BinaryType.
   *  - TINYINT and SMALLINT both map to ShortType.
   *  - DECIMAL maps to DecimalType(size, scale).
   * Any other combination yields `None`.
   *
   * @param sqlType  `java.sql.Types` code reported for the column
   * @param typeName database type name reported for the column
   * @param size     column size; used as the DECIMAL precision
   * @param md       metadata builder; for DECIMAL the scale is read from its "scale" entry
   *                 (presumably populated by Spark's JDBC schema reader; verify against caller)
   * @throws IllegalArgumentException when a DECIMAL precision exceeds `DecimalType.MAX_PRECISION`
   */
  override def getCatalystType(sqlType: Int,
                               typeName: String,
                               size: Int,
                               md: MetadataBuilder): Option[DataType] =
    (sqlType, typeName) match {
      case (Types.REAL, "FLOAT")        => Some(FloatType)
      case (Types.BIT, "BIT")           => Some(BinaryType)
      // TINYINT is read as ShortType, although getJDBCType writes ByteType as TINYINT.
      case (Types.TINYINT, "TINYINT")   => Some(ShortType)
      case (Types.SMALLINT, "SMALLINT") => Some(ShortType)
      case (Types.DECIMAL, "DECIMAL") =>
        if (size > DecimalType.MAX_PRECISION) {
          throw new IllegalArgumentException(
            s"DECIMAL precision ${size} exceeds max precision ${DecimalType.MAX_PRECISION}")
        } else {
          Some(DecimalType(size, md.build().getLong("scale").toInt))
        }
      case _ => None
    }

  /**
   * Quotes an identifier with backticks.
   *
   * Any backtick inside the name is doubled, so the name cannot close the quoted identifier
   * early. For example, the name a`b is rendered as `a``b`.
   */
  override def quoteIdentifier(colName: String): String = {
    val escaped = colName.replace("`", "``")
    s"`$escaped`"
  }

  /** Reports that TRUNCATE on a MemSQL table does not cascade. */
  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}
Example 2
Source file: JdbcUtil.scala, from the Apache Bahir project (Apache License 2.0)
package org.apache.bahir.sql.streaming.jdbc

import java.sql.{Connection, PreparedStatement}
import java.util.Locale

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


/** Helpers for resolving JDBC column types and binding Spark [[Row]] values to statements. */
object JdbcUtil {

  /**
   * Resolves the JDBC type for `dt`.
   *
   * `dialect` is asked first. Spark's common JDBC mapping is consulted only when the
   * dialect has no answer.
   *
   * @throws IllegalArgumentException when neither source knows the type
   */
  def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
    val resolved = dialect.getJDBCType(dt) orElse JdbcUtils.getCommonJDBCType(dt)
    resolved.getOrElse {
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")
    }
  }

  // Binds one field of a `Row` into a `PreparedStatement`. The `Int` argument is the
  // zero-based field index in the row; the setter writes it to JDBC parameter index + 1.
  type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  /**
   * Builds the setter that binds a value of `dataType` into a statement.
   *
   * Setters read the row field with the matching typed accessor and do not check for null.
   * Callers presumably handle null fields (e.g. via `setNull`) before invoking them; confirm.
   *
   * Array columns are bound through `conn.createArrayOf`. The element type name comes from
   * [[getJdbcType]] with any parenthesised suffix removed. Unsupported types get a setter
   * that throws `IllegalArgumentException` when invoked.
   */
  def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType => (stmt, row, pos) => stmt.setInt(pos + 1, row.getInt(pos))
    case LongType    => (stmt, row, pos) => stmt.setLong(pos + 1, row.getLong(pos))
    case DoubleType  => (stmt, row, pos) => stmt.setDouble(pos + 1, row.getDouble(pos))
    case FloatType   => (stmt, row, pos) => stmt.setFloat(pos + 1, row.getFloat(pos))
    // Short and byte values are widened and bound with setInt.
    case ShortType   => (stmt, row, pos) => stmt.setInt(pos + 1, row.getShort(pos))
    case ByteType    => (stmt, row, pos) => stmt.setInt(pos + 1, row.getByte(pos))
    case BooleanType => (stmt, row, pos) => stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      // The field may hold either Spark's UTF8String or a plain java String.
      (stmt, row, pos) => {
        val text = row.get(pos) match {
          case utf8: UTF8String => utf8.toString
          case plain: String    => plain
        }
        stmt.setString(pos + 1, text)
      }

    case BinaryType =>
      (stmt, row, pos) => stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
    case TimestampType =>
      (stmt, row, pos) => stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
    case DateType =>
      (stmt, row, pos) => stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
    case _: DecimalType =>
      (stmt, row, pos) => stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case ArrayType(elementType, _) =>
      // createArrayOf takes the element type name only, so drop any "(...)" length suffix.
      val elementTypeName = getJdbcType(elementType, dialect).databaseTypeDefinition
        .toLowerCase(Locale.ROOT)
        .split("\\(")(0)
      (stmt, row, pos) => {
        val elements = row.getSeq[AnyRef](pos).toArray
        stmt.setArray(pos + 1, conn.createArrayOf(elementTypeName, elements))
      }

    case _ =>
      (_, _, pos) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
}