org.apache.spark.sql.catalyst.expressions.Cast Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.expressions.Cast. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: MessageDelimiter.scala    From spark-cep   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.sources

import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Cast, EmptyRow, Literal}
import org.apache.spark.sql.types.StructType

/**
 * Converts between delimiter-separated text messages and rows.
 *
 * Incoming messages are split on `delimiter` and each token is cast to the data type of the
 * schema field at the same position; outgoing rows are joined with the same delimiter.
 */
class MessageDelimiter extends MessageToRowConverter with Logging {
  val delimiter = " "

  /**
   * Splits `msg` on the delimiter and casts the i-th token to the type of the i-th field
   * of `schema`.
   */
  def toRow(msg: String, schema: StructType): InternalRow = {
    val tokens = msg.split(delimiter)
    val values = tokens.zipWithIndex.map { case (token, position) =>
      Cast(Literal(token), schema(position).dataType).eval(EmptyRow)
    }
    InternalRow.fromSeq(values.toSeq)
  }

  /** Renders the values of `row` as a single string separated by the delimiter. */
  def toMessage(row: Row): String = {
    row.mkString(delimiter)
  }
}

/**
 * Contract for translating between string messages and rows.
 *
 * The trait extends `Serializable`, so implementations are expected to be serializable.
 */
trait MessageToRowConverter extends Serializable {
  /** Builds an [[InternalRow]] from `message`, guided by `schema`. */
  def toRow(message: String, schema: StructType): InternalRow

  /** Renders `row` as a string message. */
  def toMessage(row: Row): String
}
Example 2
Source File: HiveTypeCoercionSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive


/**
 * Registers query tests that combine literals of different types (addition and CASE WHEN
 * with a null branch) and checks the SPARK-2210 regression on redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")

  // One addition query for every ordered pair of literals.
  for (lhs <- baseTypes; rhs <- baseTypes) {
    createQueryTest(s"$lhs + $rhs", s"SELECT $lhs + $rhs FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with null in either branch, for every literal except the last one.
  for (literal <- baseTypes.init) {
    createQueryTest(
      s"case when then $literal else $nullVal end ",
      s"SELECT case when true then $literal else $nullVal end FROM src limit 1")
    createQueryTest(
      s"case when then $nullVal else $literal end ",
      s"SELECT case when true then $nullVal else $literal end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val query = "select cast(cast(key=0 as boolean) as boolean) from src"
    val projections = TestHive.sql(query).queryExecution.executedPlan.collect {
      case p: Project => p
    }
    val project = projections.head

    // Fail on any Cast expression visited from the projection.
    project.transformAllExpressions {
      case unexpected: Cast =>
        fail(s"unexpected cast $unexpected")
        unexpected
    }

    // Count the EqualTo expressions; exactly one is expected.
    var numEquals = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        numEquals += 1
        equality
    }
    assert(numEquals === 1)
  }
}
Example 3
Source File: KinesisWriteTask.scala    From kinesis-sql   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.kinesis

import java.nio.ByteBuffer

import com.amazonaws.services.kinesis.producer.{KinesisProducer, UserRecordResult}
import com.google.common.util.concurrent.{FutureCallback, Futures}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, UnsafeProjection}
import org.apache.spark.sql.types.{BinaryType, StringType}

/**
 * Writes rows to a Kinesis stream through a [[KinesisProducer]].
 *
 * Every input row is projected to a (partition key, data) pair which is then handed to the
 * producer.
 *
 * @param producerConfiguration producer settings; also read for the sink stream name
 * @param inputSchema attributes of the incoming rows; must contain the partition-key and data
 *                    attributes named by `KinesisWriter`, each of String or Binary type
 */
private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, String],
                                        inputSchema: Seq[Attribute]) extends Logging {

  private var producer: KinesisProducer = _
  private val projection = createProjection
  private val streamName = producerConfiguration.getOrElse(
    KinesisSourceProvider.SINK_STREAM_NAME_KEY, "")

  /**
   * Obtains a producer for the configuration and sends every row of `iterator` to the stream,
   * one record per row.
   */
  def execute(iterator: Iterator[InternalRow]): Unit = {
    producer = CachedKinesisProducer.getOrCreate(producerConfiguration)
    while (iterator.hasNext) {
      val currentRow = iterator.next()
      val projectedRow = projection(currentRow)
      val partitionKey = projectedRow.getString(0)
      val data = projectedRow.getBinary(1)

      sendData(partitionKey, data)
    }
  }

  /**
   * Adds one record to the producer and synchronously flushes it.
   *
   * @param partitionKey partition key of the record
   * @param data         payload bytes of the record
   * @return the sequence number stored by the success callback, or an empty string when the
   *         callback has not stored one by the time the flush returns
   */
  def sendData(partitionKey: String, data: Array[Byte]): String = {
    var sentSeqNumbers = new String

    val future = producer.addUserRecord(streamName, partitionKey, ByteBuffer.wrap(data))

    val kinesisCallBack = new FutureCallback[UserRecordResult]() {

      override def onFailure(t: Throwable): Unit = {
        // Hand the throwable itself to the logger: getCause may be null, and logging only the
        // cause would drop the failure's own message and stack trace.
        logError(s"Writing to  $streamName failed due to ${t.getCause}", t)
      }

      override def onSuccess(result: UserRecordResult): Unit = {
        sentSeqNumbers = result.getSequenceNumber
      }
    }
    Futures.addCallback(future, kinesisCallBack)

    producer.flushSync()
    sentSeqNumbers
  }

  /** Flushes the producer, if one was obtained, and drops the reference to it. */
  def close(): Unit = {
    if (producer != null) {
      producer.flush()
      producer = null
    }
  }

  /**
   * Builds the projection that extracts (partition key as string, data as binary) from an
   * input row.
   *
   * @throws IllegalStateException if a required attribute is missing or is neither of String
   *                               nor of Binary type
   */
  private def createProjection: UnsafeProjection = {

    val partitionKeyExpression = inputSchema
      .find(_.name == KinesisWriter.PARTITION_KEY_ATTRIBUTE_NAME).getOrElse(
      throw new IllegalStateException("Required attribute " +
        s"'${KinesisWriter.PARTITION_KEY_ATTRIBUTE_NAME}' not found"))

    partitionKeyExpression.dataType match {
      case StringType | BinaryType => // ok
      case _ =>
        throw new IllegalStateException(s"${KinesisWriter.PARTITION_KEY_ATTRIBUTE_NAME} " +
          "attribute type must be a String or BinaryType")
    }

    val dataExpression = inputSchema.find(_.name == KinesisWriter.DATA_ATTRIBUTE_NAME).getOrElse(
      throw new IllegalStateException("Required attribute " +
        s"'${KinesisWriter.DATA_ATTRIBUTE_NAME}' not found")
    )

    dataExpression.dataType match {
      case StringType | BinaryType => // ok
      case _ =>
        throw new IllegalStateException(s"${KinesisWriter.DATA_ATTRIBUTE_NAME} " +
          "attribute type must be a String or BinaryType")
    }

    // The data column is read back with getBinary in execute, so project it as BinaryType
    // rather than StringType to keep the projected type and the accessor in agreement.
    UnsafeProjection.create(
      Seq(Cast(partitionKeyExpression, StringType), Cast(dataExpression, BinaryType)), inputSchema)
  }

}
Example 4
Source File: ResolveInlineTablesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}


/** Tests for the validation and conversion steps of `ResolveInlineTables`. */
class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  // Builds a new rule instance from the suite's conf on every use.
  private def inlineTablesRule = ResolveInlineTables(conf)

  test("validate inputs are foldable") {
    val literalTable = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))
    inlineTablesRule.validateInputEvaluable(literalTable)

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      val randTable = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))
      inlineTablesRule.validateInputEvaluable(randTable)
    }

    // aggregate should not work
    intercept[AnalysisException] {
      val aggregateTable = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))
      inlineTablesRule.validateInputEvaluable(aggregateTable)
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      val unresolvedTable = UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))
      inlineTablesRule.validateInputEvaluable(unresolvedTable)
    }
  }

  test("validate input dimensions") {
    val wellFormed = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))
    inlineTablesRule.validateInputDimension(wellFormed)

    // num alias != data dimension
    intercept[AnalysisException] {
      val tooManyNames = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))
      inlineTablesRule.validateInputDimension(tooManyNames)
    }

    // num alias == data dimension, but data themselves are inconsistent
    intercept[AnalysisException] {
      val raggedRows = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))
      inlineTablesRule.validateInputDimension(raggedRows)
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    val result = inlineTablesRule(table)
    assert(result == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = inlineTablesRule.convert(table)

    // The int and long literals end up in a single long column.
    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("convert TimeZoneAwareExpression") {
    val timestampText = "1991-12-06 00:00:00.0"
    val table = UnresolvedInlineTable(Seq("c1"),
      Seq(Seq(Cast(lit(timestampText), TimestampType))))
    val withTimeZone = ResolveTimeZone(conf).apply(table)
    val LocalRelation(output, data, _) = inlineTablesRule.apply(withTimeZone)

    // Expected value: the same cast evaluated with the session-local time zone.
    val correct = Cast(lit(timestampText), TimestampType)
      .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
    assert(output.map(_.dataType) == Seq(TimestampType))
    assert(data.size == 1)
    assert(data.head.getLong(0) == correct)
  }

  test("nullability inference in convert") {
    // Only non-null literals: the column is not nullable.
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = inlineTablesRule.convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    // A null literal in any row makes the column nullable.
    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = inlineTablesRule.convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
}
Example 5
Source File: HiveTypeCoercionSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive


/**
 * Registers query tests that combine literals of different types (addition and CASE WHEN
 * with a null branch) and checks the SPARK-2210 regression on redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // Pairs of (text used in the test name, SQL text used in the query).
  val baseTypes = Seq(
    ("1", "1"),
    ("1.0", "CAST(1.0 AS DOUBLE)"),
    ("1L", "1L"),
    ("1S", "1S"),
    ("1Y", "1Y"),
    ("'1'", "'1'"))

  // One addition query for every ordered pair of literals.
  for ((leftName, leftSql) <- baseTypes; (rightName, rightSql) <- baseTypes) {
    createQueryTest(s"$leftName + $rightName", s"SELECT $leftSql + $rightSql FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with null in either branch, for every literal except the last one.
  for ((name, sql) <- baseTypes.init) {
    createQueryTest(
      s"case when then $name else $nullVal end ",
      s"SELECT case when true then $sql else $nullVal end FROM src limit 1")
    createQueryTest(
      s"case when then $nullVal else $name end ",
      s"SELECT case when true then $nullVal else $sql end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val query = "select cast(cast(key=0 as boolean) as boolean) from src"
    val projections = TestHive.sql(query).queryExecution.sparkPlan.collect {
      case p: ProjectExec => p
    }
    val project = projections.head

    // Fail on any Cast expression visited from the projection.
    project.transformAllExpressions {
      case unexpected: Cast =>
        fail(s"unexpected cast $unexpected")
        unexpected
    }

    // Count the EqualTo expressions; exactly one is expected.
    var numEquals = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        numEquals += 1
        equality
    }
    assert(numEquals === 1)
  }
}
Example 6
Source File: HiveTypeCoercionSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive


/**
 * Registers query tests that combine literals of different types (addition and CASE WHEN
 * with a null branch) and checks the SPARK-2210 regression on redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")

  // One addition query for every ordered pair of literals.
  for (lhs <- baseTypes; rhs <- baseTypes) {
    createQueryTest(s"$lhs + $rhs", s"SELECT $lhs + $rhs FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with null in either branch, for every literal except the last one.
  for (literal <- baseTypes.init) {
    createQueryTest(
      s"case when then $literal else $nullVal end ",
      s"SELECT case when true then $literal else $nullVal end FROM src limit 1")
    createQueryTest(
      s"case when then $nullVal else $literal end ",
      s"SELECT case when true then $nullVal else $literal end FROM src limit 1")
  }

  // A boolean cast applied to a value that is already boolean should be removed.
  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val query = "select cast(cast(key=0 as boolean) as boolean) from src"
    val projections = TestHive.sql(query).queryExecution.executedPlan.collect {
      case p: Project => p
    }
    val project = projections.head

    // No cast expression may have been introduced.
    project.transformAllExpressions {
      case unexpected: Cast =>
        fail(s"unexpected cast $unexpected")
        unexpected
    }

    // There must be exactly one equality check.
    var numEquals = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        numEquals += 1
        equality
    }
    assert(numEquals === 1)
  }
}
Example 7
Source File: DateUtils.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import java.sql.Date
import java.text.SimpleDateFormat
import java.util.{Calendar, TimeZone}

import org.apache.spark.sql.catalyst.expressions.Cast


/**
 * Helpers for converting between `java.sql.Date`, epoch milliseconds and the "days since
 * epoch" integer representation of dates, plus a lenient date/time string parser.
 */
object DateUtils {
  private val MILLIS_PER_DAY = 86400000

  // Java TimeZone has no mention of thread safety. Use thread local instance to be safe.
  private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] {
    override protected def initialValue: TimeZone = {
      Calendar.getInstance.getTimeZone
    }
  }

  private def javaDateToDays(d: Date): Int = {
    millisToDays(d.getTime)
  }

  /**
   * Converts epoch milliseconds to the number of days since the epoch in the local time zone.
   *
   * we should use the exact day as Int, for example, (year, month, day) -> day
   *
   * @param millisLocal milliseconds since the epoch; the local zone offset for that instant
   *                    is added before the day is computed
   * @return the day containing the instant; negative for instants before the epoch
   */
  def millisToDays(millisLocal: Long): Int = {
    val shifted = millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)
    // Use floor division: plain `/` truncates toward zero, which would place an instant
    // before the epoch that is not exactly on a day boundary on the following day
    // (e.g. -1 ms would become day 0 instead of day -1).
    val quotient = shifted / MILLIS_PER_DAY
    val days = if (shifted % MILLIS_PER_DAY < 0) quotient - 1 else quotient
    days.toInt
  }

  // Inverse of millisToDays for day boundaries: start of the given day minus the zone offset.
  // NOTE(review): the offset is looked up at the UTC-based instant, not the local one; this
  // may be off around DST transitions - confirm whether that matters for callers.
  private def toMillisSinceEpoch(days: Int): Long = {
    val millisUtc = days.toLong * MILLIS_PER_DAY
    millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc)
  }

  /** Returns the number of days since the epoch for `date`. */
  def fromJavaDate(date: java.sql.Date): Int = {
    javaDateToDays(date)
  }

  /** Builds a `java.sql.Date` from a number of days since the epoch. */
  def toJavaDate(daysSinceEpoch: Int): java.sql.Date = {
    new java.sql.Date(toMillisSinceEpoch(daysSinceEpoch))
  }

  /** Formats a number of days since the epoch with `Cast.threadLocalDateFormat`. */
  def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days))

  /**
   * Parses a date/time string.
   *
   * Strings without a 'T' are treated as JDBC escape strings (timestamp when they contain a
   * space, date otherwise). Strings with a 'T' are normalised step by step into the
   * `yyyy-MM-dd'T'HH:mm:ss.SSSz` form with an explicit "GMT" zone and then parsed.
   */
  def stringToTime(s: String): java.util.Date = {
    if (!s.contains('T')) {
      // JDBC escape string
      if (s.contains(' ')) {
        java.sql.Timestamp.valueOf(s)
      } else {
        java.sql.Date.valueOf(s)
      }
    } else if (s.endsWith("Z")) {
      // this is zero timezone of ISO8601
      stringToTime(s.substring(0, s.length - 1) + "GMT-00:00")
    } else if (s.indexOf("GMT") == -1) {
      // timezone with ISO8601: the last six characters are taken as the zone offset
      val inset = "+00.00".length
      val s0 = s.substring(0, s.length - inset)
      val s1 = s.substring(s.length - inset, s.length)
      // Add a ".0" fraction when the seconds part has none, so the pattern below matches.
      if (s0.substring(s0.lastIndexOf(':')).contains('.')) {
        stringToTime(s0 + "GMT" + s1)
      } else {
        stringToTime(s0 + ".0GMT" + s1)
      }
    } else {
      // ISO8601 with GMT insert; a new formatter is created on every call
      val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" )
      ISO8601GMT.parse(s)
    }
  }
}
Example 8
Source File: HiveTypeCoercionSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive


/**
 * Registers query tests that combine literals of different types (addition and CASE WHEN
 * with a null branch), checks the SPARK-2210 regression on redundant boolean casts, and
 * checks that COALESCE over mixed types is rejected.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")

  // One addition query for every ordered pair of literals.
  for (lhs <- baseTypes; rhs <- baseTypes) {
    createQueryTest(s"$lhs + $rhs", s"SELECT $lhs + $rhs FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with null in either branch, for every literal except the last one.
  for (literal <- baseTypes.init) {
    createQueryTest(
      s"case when then $literal else $nullVal end ",
      s"SELECT case when true then $literal else $nullVal end FROM src limit 1")
    createQueryTest(
      s"case when then $nullVal else $literal end ",
      s"SELECT case when true then $nullVal else $literal end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val query = "select cast(cast(key=0 as boolean) as boolean) from src"
    val projections = TestHive.sql(query).queryExecution.executedPlan.collect {
      case p: Project => p
    }
    val project = projections.head

    // Fail on any Cast expression visited from the projection.
    project.transformAllExpressions {
      case unexpected: Cast =>
        fail(s"unexpected cast $unexpected")
        unexpected
    }

    // Count the EqualTo expressions; exactly one is expected.
    var numEquals = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        numEquals += 1
        equality
    }
    assert(numEquals === 1)
  }

  test("COALESCE with different types") {
    val query = """SELECT COALESCE(1, true, "abc") FROM src limit 1"""
    // Running the query is expected to raise a RuntimeException.
    intercept[RuntimeException] {
      TestHive.sql(query).collect()
    }
  }
}
Example 9
Source File: ResolveInlineTables.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}


  /**
   * Converts an [[UnresolvedInlineTable]] into a [[LocalRelation]].
   *
   * Each column takes the wider common type of its values (analysis fails when none exists)
   * and is nullable if any of its values is. Every expression is then evaluated, through a
   * cast to the column type when its own type differs.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // Derive one StructField per column from the transposed rows.
    val namedColumns = table.rows.transpose.zip(table.names)
    val fields = namedColumns.map { case (values, columnName) =>
      val widened = TypeCoercion.findWiderTypeWithoutStringPromotion(values.map(_.dataType))
      val columnType = widened.getOrElse {
        table.failAnalysis(s"incompatible types found in column $columnName for inline table")
      }
      StructField(columnName, columnType, nullable = values.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      val evaluated = row.zipWithIndex.map { case (expr, columnIndex) =>
        val targetType = fields(columnIndex).dataType
        try {
          // Cast only when the expression's type differs from the column type.
          val toEvaluate = if (expr.dataType.sameType(targetType)) expr else Cast(expr, targetType)
          toEvaluate.eval()
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${expr.sql}: ${ex.getMessage}")
        }
      }
      InternalRow.fromSeq(evaluated)
    }

    LocalRelation(attributes, newRows)
  }
} 
Example 10
Source File: HiveTypeCoercionSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive


/**
 * Registers query tests that combine literals of different types (addition and CASE WHEN
 * with a null branch) and checks the SPARK-2210 regression on redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // Pairs of (text used in the test name, SQL text used in the query).
  val baseTypes = Seq(
    ("1", "1"),
    ("1.0", "CAST(1.0 AS DOUBLE)"),
    ("1L", "1L"),
    ("1S", "1S"),
    ("1Y", "1Y"),
    ("'1'", "'1'"))

  // One addition query for every ordered pair of literals.
  for ((leftName, leftSql) <- baseTypes; (rightName, rightSql) <- baseTypes) {
    createQueryTest(s"$leftName + $rightName", s"SELECT $leftSql + $rightSql FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with null in either branch, for every literal except the last one.
  for ((name, sql) <- baseTypes.init) {
    createQueryTest(
      s"case when then $name else $nullVal end ",
      s"SELECT case when true then $sql else $nullVal end FROM src limit 1")
    createQueryTest(
      s"case when then $nullVal else $name end ",
      s"SELECT case when true then $nullVal else $sql end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val query = "select cast(cast(key=0 as boolean) as boolean) from src"
    val projections = TestHive.sql(query).queryExecution.sparkPlan.collect {
      case p: ProjectExec => p
    }
    val project = projections.head

    // Fail on any Cast expression visited from the projection.
    project.transformAllExpressions {
      case unexpected: Cast =>
        fail(s"unexpected cast $unexpected")
        unexpected
    }

    // Count the EqualTo expressions; exactly one is expected.
    var numEquals = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        numEquals += 1
        equality
    }
    assert(numEquals === 1)
  }
}
Example 11
Source File: HiveTypeCoercionSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive


/**
 * Registers query tests that combine literals of different types (addition and CASE WHEN
 * with a null branch) and checks the SPARK-2210 regression on redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // Pairs of (text used in the test name, SQL text used in the query).
  val baseTypes = Seq(
    ("1", "1"),
    ("1.0", "CAST(1.0 AS DOUBLE)"),
    ("1L", "1L"),
    ("1S", "1S"),
    ("1Y", "1Y"),
    ("'1'", "'1'"))

  // One addition query for every ordered pair of literals.
  for ((leftName, leftSql) <- baseTypes; (rightName, rightSql) <- baseTypes) {
    createQueryTest(s"$leftName + $rightName", s"SELECT $leftSql + $rightSql FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with null in either branch, for every literal except the last one.
  for ((name, sql) <- baseTypes.init) {
    createQueryTest(
      s"case when then $name else $nullVal end ",
      s"SELECT case when true then $sql else $nullVal end FROM src limit 1")
    createQueryTest(
      s"case when then $nullVal else $name end ",
      s"SELECT case when true then $nullVal else $sql end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val query = "select cast(cast(key=0 as boolean) as boolean) from src"
    val projections = TestHive.sql(query).queryExecution.sparkPlan.collect {
      case p: ProjectExec => p
    }
    val project = projections.head

    // Fail on any Cast expression visited from the projection.
    project.transformAllExpressions {
      case unexpected: Cast =>
        fail(s"unexpected cast $unexpected")
        unexpected
    }

    // Count the EqualTo expressions; exactly one is expected.
    var numEquals = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        numEquals += 1
        equality
    }
    assert(numEquals === 1)
  }
}
Example 12
Source File: CarbonExpressions.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.types.DataType


  /** Extractor that matches expressions which are `ScalaUDF` instances. */
  object CarbonScalaUDF {
    /** Returns the expression as a `ScalaUDF` when it is one, and `None` otherwise. */
    def unapply(expression: Expression): Option[(ScalaUDF)] =
      PartialFunction.condOpt(expression) { case udf: ScalaUDF => udf }
  }
} 
Example 13
Source File: ResolveInlineTables.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}


  /**
   * Converts an [[UnresolvedInlineTable]] into a [[LocalRelation]].
   *
   * Each column takes the wider common type of its values (analysis fails when none exists)
   * and is nullable if any of its values is. Every expression is then evaluated, through a
   * cast to the column type when its own type differs.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // Derive one StructField per column from the transposed rows.
    val namedColumns = table.rows.transpose.zip(table.names)
    val fields = namedColumns.map { case (values, columnName) =>
      val widened = TypeCoercion.findWiderTypeWithoutStringPromotion(values.map(_.dataType))
      val columnType = widened.getOrElse {
        table.failAnalysis(s"incompatible types found in column $columnName for inline table")
      }
      StructField(columnName, columnType, nullable = values.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      val evaluated = row.zipWithIndex.map { case (expr, columnIndex) =>
        val targetType = fields(columnIndex).dataType
        try {
          // Cast only when the expression's type differs from the column type.
          val toEvaluate = if (expr.dataType.sameType(targetType)) expr else Cast(expr, targetType)
          toEvaluate.eval()
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${expr.sql}: ${ex.getMessage}")
        }
      }
      InternalRow.fromSeq(evaluated)
    }

    LocalRelation(attributes, newRows)
  }
} 
Example 14
Source File: HiveTypeCoercionSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive


/**
 * Registers query tests that combine literals of different types (addition and CASE WHEN
 * with a null branch) and checks the SPARK-2210 regression on redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // Pairs of (text used in the test name, SQL text used in the query).
  val baseTypes = Seq(
    ("1", "1"),
    ("1.0", "CAST(1.0 AS DOUBLE)"),
    ("1L", "1L"),
    ("1S", "1S"),
    ("1Y", "1Y"),
    ("'1'", "'1'"))

  // One addition query for every ordered pair of literals.
  for ((leftName, leftSql) <- baseTypes; (rightName, rightSql) <- baseTypes) {
    createQueryTest(s"$leftName + $rightName", s"SELECT $leftSql + $rightSql FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with null in either branch, for every literal except the last one.
  for ((name, sql) <- baseTypes.init) {
    createQueryTest(
      s"case when then $name else $nullVal end ",
      s"SELECT case when true then $sql else $nullVal end FROM src limit 1")
    createQueryTest(
      s"case when then $nullVal else $name end ",
      s"SELECT case when true then $nullVal else $sql end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val query = "select cast(cast(key=0 as boolean) as boolean) from src"
    val projections = TestHive.sql(query).queryExecution.sparkPlan.collect {
      case p: ProjectExec => p
    }
    val project = projections.head

    // Fail on any Cast expression visited from the projection.
    project.transformAllExpressions {
      case unexpected: Cast =>
        fail(s"unexpected cast $unexpected")
        unexpected
    }

    // Count the EqualTo expressions; exactly one is expected.
    var numEquals = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        numEquals += 1
        equality
    }
    assert(numEquals === 1)
  }
}
Example 15
Source File: GlobalSapSQLContext.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import java.io.File

import com.sap.spark.util.TestUtils
import com.sap.spark.{GlobalSparkContext, WithSQLContext}
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast}
import org.apache.spark.unsafe.types._
import org.apache.spark.sql.types._
import org.scalatest.Suite

import scala.io.Source

/**
 * Test mix-in that exposes the `SQLContext` held by the `GlobalSapSQLContext` object and
 * provides a helper to load a comma-separated file into a `DataFrame`.
 */
trait GlobalSapSQLContext extends GlobalSparkContext with WithSQLContext {
  self: Suite =>

  override implicit def sqlContext: SQLContext = GlobalSapSQLContext._sqlc

  override protected def setUpSQLContext(): Unit =
    GlobalSapSQLContext.init(sc)

  override protected def tearDownSQLContext(): Unit =
    GlobalSapSQLContext.reset()

  /**
   * Reads a comma-separated text file and builds a `DataFrame` with the given schema.
   *
   * Every line is split on "," (trailing empty fields are kept) and the i-th value is cast
   * from string to the data type of the i-th schema field.
   *
   * @param sparkSchema schema of the resulting `DataFrame`
   * @param path        file to read
   */
  def getDataFrameFromSourceFile(sparkSchema: StructType, path: File): DataFrame = {
    val conversions = sparkSchema.toSeq.zipWithIndex.map({
      case (field, index) =>
        Cast(BoundReference(index, StringType, nullable = true), field.dataType)
    })
    val source = Source.fromFile(path)
    // Materialise all rows before closing: getLines() is lazy, and the source must be
    // closed so the underlying file handle is not leaked.
    val data = try {
      source.getLines().map({ line =>
        val stringRow = InternalRow.fromSeq(line.split(",", -1).map(UTF8String.fromString))
        Row.fromSeq(conversions.map({ c => c.eval(stringRow) }))
      }).toVector
    } finally {
      source.close()
    }
    val rdd = sc.parallelize(data, numberOfSparkWorkers)
    sqlContext.createDataFrame(rdd, sparkSchema)
  }
}

/** Holds a single `SQLContext` instance that is created once and then reused. */
object GlobalSapSQLContext {

  private var _sqlc: SQLContext = _

  // Creates the context on the first call; later calls keep the existing instance.
  private def init(sc: SparkContext): Unit = {
    val alreadyCreated = _sqlc != null
    if (!alreadyCreated) {
      _sqlc = TestUtils.newSQLContext(sc)
    }
  }

  // Unregisters all tables of the context, if one has been created.
  private def reset(): Unit =
    Option(_sqlc).foreach(_.catalog.unregisterAllTables())

}
Example 16
Source File: ResolveInlineTablesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}


/** Tests for the validation and conversion steps of `ResolveInlineTables`. */
class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  // Builds a new rule instance from the suite's conf on every use.
  private def inlineTablesRule = ResolveInlineTables(conf)

  test("validate inputs are foldable") {
    val literalTable = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))
    inlineTablesRule.validateInputEvaluable(literalTable)

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      val randTable = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))
      inlineTablesRule.validateInputEvaluable(randTable)
    }

    // aggregate should not work
    intercept[AnalysisException] {
      val aggregateTable = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))
      inlineTablesRule.validateInputEvaluable(aggregateTable)
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      val unresolvedTable = UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))
      inlineTablesRule.validateInputEvaluable(unresolvedTable)
    }
  }

  test("validate input dimensions") {
    val wellFormed = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))
    inlineTablesRule.validateInputDimension(wellFormed)

    // num alias != data dimension
    intercept[AnalysisException] {
      val tooManyNames = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))
      inlineTablesRule.validateInputDimension(tooManyNames)
    }

    // num alias == data dimension, but data themselves are inconsistent
    intercept[AnalysisException] {
      val raggedRows = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))
      inlineTablesRule.validateInputDimension(raggedRows)
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    val result = inlineTablesRule(table)
    assert(result == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = inlineTablesRule.convert(table)

    // The int and long literals end up in a single long column.
    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("convert TimeZoneAwareExpression") {
    val timestampText = "1991-12-06 00:00:00.0"
    val table = UnresolvedInlineTable(Seq("c1"),
      Seq(Seq(Cast(lit(timestampText), TimestampType))))
    val withTimeZone = ResolveTimeZone(conf).apply(table)
    val LocalRelation(output, data, _) = inlineTablesRule.apply(withTimeZone)

    // Expected value: the same cast evaluated with the session-local time zone.
    val correct = Cast(lit(timestampText), TimestampType)
      .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
    assert(output.map(_.dataType) == Seq(TimestampType))
    assert(data.size == 1)
    assert(data.head.getLong(0) == correct)
  }

  test("nullability inference in convert") {
    // Only non-null literals: the column is not nullable.
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = inlineTablesRule.convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    // A null literal in any row makes the column nullable.
    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = inlineTablesRule.convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
}
Example 17
Source File: view.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf


/** Rule that removes `View` operators from a plan, replacing each one with its child. */
object EliminateView extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    case View(_, viewOutput, viewChild) =>
      // Renders an attribute list as "[a,b,c]" for the assertion message.
      def bracketed(attrs: Seq[Any]): String = attrs.mkString("[", ",", "]")

      // The child is expected to expose exactly the view's output attributes, so the View
      // operator can simply be dropped.
      assert(viewOutput == viewChild.output,
        s"The output of the child ${bracketed(viewChild.output)} is different from the " +
          s"view output ${bracketed(viewOutput)}")
      viewChild
  }
}
Example 18
Source File: HiveTypeCoercionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive


/**
 * Registers query tests that combine literals of different types (addition and CASE WHEN
 * with a null branch) and checks the SPARK-2210 regression on redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // Pairs of (text used in the test name, SQL text used in the query).
  val baseTypes = Seq(
    ("1", "1"),
    ("1.0", "CAST(1.0 AS DOUBLE)"),
    ("1L", "1L"),
    ("1S", "1S"),
    ("1Y", "1Y"),
    ("'1'", "'1'"))

  // One addition query for every ordered pair of literals.
  for ((leftName, leftSql) <- baseTypes; (rightName, rightSql) <- baseTypes) {
    createQueryTest(s"$leftName + $rightName", s"SELECT $leftSql + $rightSql FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with null in either branch, for every literal except the last one.
  for ((name, sql) <- baseTypes.init) {
    createQueryTest(
      s"case when then $name else $nullVal end ",
      s"SELECT case when true then $sql else $nullVal end FROM src limit 1")
    createQueryTest(
      s"case when then $nullVal else $name end ",
      s"SELECT case when true then $nullVal else $sql end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val query = "select cast(cast(key=0 as boolean) as boolean) from src"
    val projections = TestHive.sql(query).queryExecution.sparkPlan.collect {
      case p: ProjectExec => p
    }
    val project = projections.head

    // Fail on any Cast expression visited from the projection.
    project.transformAllExpressions {
      case unexpected: Cast =>
        fail(s"unexpected cast $unexpected")
        unexpected
    }

    // Count the EqualTo expressions; exactly one is expected.
    var numEquals = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        numEquals += 1
        equality
    }
    assert(numEquals === 1)
  }
}
Example 19
Source File: ResolveInlineTables.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}


  /**
   * Converts an [[UnresolvedInlineTable]] into a [[LocalRelation]].
   *
   * Each column takes the wider common type of its values (analysis fails when none exists)
   * and is nullable if any of its values is. Every expression is then evaluated, through a
   * cast to the column type when its own type differs.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // Derive one StructField per column from the transposed rows.
    val namedColumns = table.rows.transpose.zip(table.names)
    val fields = namedColumns.map { case (values, columnName) =>
      val widened = TypeCoercion.findWiderTypeWithoutStringPromotion(values.map(_.dataType))
      val columnType = widened.getOrElse {
        table.failAnalysis(s"incompatible types found in column $columnName for inline table")
      }
      StructField(columnName, columnType, nullable = values.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      val evaluated = row.zipWithIndex.map { case (expr, columnIndex) =>
        val targetType = fields(columnIndex).dataType
        try {
          // Cast only when the expression's type differs from the column type.
          val toEvaluate = if (expr.dataType.sameType(targetType)) expr else Cast(expr, targetType)
          toEvaluate.eval()
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${expr.sql}: ${ex.getMessage}")
        }
      }
      InternalRow.fromSeq(evaluated)
    }

    LocalRelation(attributes, newRows)
  }
}