org.apache.spark.sql.catalyst.expressions.EqualTo Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.expressions.EqualTo. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: HiveTypeCoercionSuite.scala (from drizzle-spark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive


/**
 * Type-coercion checks for Hive-compatible SQL: addition over every pair of base literal types,
 * CASE WHEN mixing a typed branch with NULL, and removal of redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // (test-name fragment, SQL literal) pairs. Only the double literal differs: the query spells
  // it as an explicit CAST.
  val baseTypes = Seq(
    ("1", "1"),
    ("1.0", "CAST(1.0 AS DOUBLE)"),
    ("1L", "1L"),
    ("1S", "1S"),
    ("1Y", "1Y"),
    ("'1'", "'1'"))

  // Addition over the full cross product of base types.
  for {
    (leftName, leftSql) <- baseTypes
    (rightName, rightSql) <- baseTypes
  } {
    createQueryTest(s"$leftName + $rightName", s"SELECT $leftSql + $rightSql FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with NULL in either branch; the last base type (the string literal) is skipped.
  for ((name, sql) <- baseTypes.init) {
    createQueryTest(s"case when then $name else $nullVal end ",
      s"SELECT case when true then $sql else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $name end ",
      s"SELECT case when true then $nullVal else $sql end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val sqlText = "select cast(cast(key=0 as boolean) as boolean) from src"
    val physicalPlan = TestHive.sql(sqlText).queryExecution.sparkPlan
    val project = physicalPlan.collect { case p: ProjectExec => p }.head

    // The redundant casts must be gone entirely.
    project.transformAllExpressions {
      case cast: Cast =>
        fail(s"unexpected cast $cast")
        cast
    }

    // Exactly one equality comparison should remain.
    var equalityCount = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        equalityCount += 1
        equality
    }
    assert(equalityCount === 1)
  }
}
Example 2
Source File: DeltaPushFilter.scala (from connectors, Apache License 2.0; 5 votes)
package org.apache.spark.sql.delta

import scala.collection.immutable.HashSet
import scala.collection.JavaConverters._

import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, SerializationUtilities}
import org.apache.hadoop.hive.ql.lib._
import org.apache.hadoop.hive.ql.parse.SemanticException
import org.apache.hadoop.hive.ql.plan.{ExprNodeColumnDesc, ExprNodeConstantDesc, ExprNodeGenericFuncDesc}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, InSet, LessThan, LessThanOrEqual, Like, Literal, Not}

object DeltaPushFilter extends Logging {
  /**
   * Fully-qualified class names of the Hive UDFs considered for partition-filter push-down.
   *
   * NOTE(review): `GenericUDFOPNotEqualNS` is handled by `partitionFilterConverter` but is not
   * listed here; confirm whether that omission is intentional.
   */
  lazy val supportedPushDownUDFs = Array(
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrGreaterThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrLessThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPLessThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPGreaterThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotEqual",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualNS",
    "org.apache.hadoop.hive.ql.udf.UDFLike",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn"
  )

  /**
   * Converts a serialized Hive partition filter into Catalyst predicates.
   *
   * The expression is deserialized with `SerializationUtilities` and walked with Hive's
   * `DefaultGraphWalker`. Each supported comparison becomes the matching Catalyst predicate
   * over an `UnresolvedAttribute`. Each `AND` node folds its children's converted outputs,
   * received through `nodeOutputs`, with Catalyst `And`.
   *
   * @param hiveFilterExprSeriablized the serialized Hive filter expression, or null when there
   *                                  is no filter
   * @return a one-element Seq holding the converted predicate, or an empty Seq for null input
   * @throws RuntimeException wrapping any exception raised during the walk, including the
   *                          error for an unsupported function
   */
  def partitionFilterConverter(hiveFilterExprSeriablized: String): Seq[Expression] = {
    if (hiveFilterExprSeriablized == null) {
      Seq.empty[Expression]
    } else {
      val filterExpr = SerializationUtilities.deserializeExpression(hiveFilterExprSeriablized)
      // No rules are registered. Presumably the dispatcher therefore falls back to
      // `nodeProcessor` for every node; confirm against the Hive DefaultRuleDispatcher docs.
      val opRules = new java.util.LinkedHashMap[Rule, NodeProcessor]()
      val nodeProcessor = new NodeProcessor() {
        @throws[SemanticException]
        def process(nd: Node, stack: java.util.Stack[Node],
            procCtx: NodeProcessorCtx, nodeOutputs: Object*): Object = {
          nd match {
            case e: ExprNodeGenericFuncDesc if FunctionRegistry.isOpAnd(e) =>
              nodeOutputs.map(_.asInstanceOf[Expression]).reduce(And)
            case e: ExprNodeGenericFuncDesc =>
              convertFunction(e)
            // Non-function nodes produce no output of their own; the enclosing function node
            // reads its column/constant operands directly.
            case _ => null
          }
        }
      }

      val disp = new DefaultRuleDispatcher(nodeProcessor, opRules, null)
      val ogw = new DefaultGraphWalker(disp)
      val topNodes = new java.util.ArrayList[Node]()
      topNodes.add(filterExpr)
      val nodeOutput = new java.util.HashMap[Node, Object]()
      try {
        ogw.startWalking(topNodes, nodeOutput)
      } catch {
        case ex: Exception =>
          throw new RuntimeException(ex)
      }
      val converted = nodeOutput.get(filterExpr).asInstanceOf[Expression]
      logInfo(s"converted partition filter expr:${converted.toJSON}")
      Seq(converted)
    }
  }

  /**
   * Converts one non-`AND` Hive function node into the equivalent Catalyst predicate.
   *
   * Operands are located lazily, only after the UDF has matched a supported case. An
   * unsupported function (e.g. a unary or `OR` node) therefore reports "Unsupported func"
   * rather than failing with an index or cast error while searching for a column/constant pair.
   */
  private def convertFunction(e: ExprNodeGenericFuncDesc): Expression = {
    val children = e.getChildren

    // A comparison may be written as `column OP constant` or `constant OP column`.
    lazy val operands =
      if (children.get(0).isInstanceOf[ExprNodeColumnDesc]) (children.get(0), children.get(1))
      else (children.get(1), children.get(0))
    lazy val columnAttr =
      UnresolvedAttribute(operands._1.asInstanceOf[ExprNodeColumnDesc].getColumn)
    lazy val constantVal =
      Literal(operands._2.asInstanceOf[ExprNodeConstantDesc].getValue)

    // The null-safe variants are matched first; keep them ahead of their plain counterparts
    // (NOTE(review): presumably because the NS UDF classes subclass the plain ones; confirm).
    e.getGenericUDF match {
      case _: GenericUDFOPNotEqualNS =>
        Not(EqualNullSafe(columnAttr, constantVal))
      case _: GenericUDFOPNotEqual =>
        Not(EqualTo(columnAttr, constantVal))
      case _: GenericUDFOPEqualNS =>
        EqualNullSafe(columnAttr, constantVal)
      case _: GenericUDFOPEqual =>
        EqualTo(columnAttr, constantVal)
      case _: GenericUDFOPGreaterThan =>
        GreaterThan(columnAttr, constantVal)
      case _: GenericUDFOPEqualOrGreaterThan =>
        GreaterThanOrEqual(columnAttr, constantVal)
      case _: GenericUDFOPLessThan =>
        LessThan(columnAttr, constantVal)
      case _: GenericUDFOPEqualOrLessThan =>
        LessThanOrEqual(columnAttr, constantVal)
      case f: GenericUDFBridge if f.getUdfName.equals("like") =>
        Like(columnAttr, constantVal)
      case _: GenericUDFIn =>
        // InSet checks membership of the evaluated column value, so its set must hold raw
        // Catalyst values (what `Literal(v).value` yields), not `Literal` wrappers. A set of
        // Literals can never contain a raw value, so IN filters would match nothing.
        val inConstantVals = children.asScala
          .collect { case c: ExprNodeConstantDesc => Literal(c.getValue).value }
          .toSet
        InSet(columnAttr, HashSet() ++ inConstantVals)
      case _ =>
        throw new RuntimeException(s"Unsupported func(${e.getName}) " +
          s"which can not be pushed down to delta")
    }
  }
}
} 
Example 3
Source File: HiveTypeCoercionSuite.scala (from XSQL, Apache License 2.0; 5 votes)
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive


/**
 * Type-coercion checks for Hive-compatible SQL: addition over every pair of base literal types,
 * CASE WHEN mixing a typed branch with NULL, and removal of redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // (test-name fragment, SQL literal) pairs. Only the double literal differs: the query spells
  // it as an explicit CAST.
  val baseTypes = Seq(
    ("1", "1"),
    ("1.0", "CAST(1.0 AS DOUBLE)"),
    ("1L", "1L"),
    ("1S", "1S"),
    ("1Y", "1Y"),
    ("'1'", "'1'"))

  // Addition over the full cross product of base types.
  for {
    (leftName, leftSql) <- baseTypes
    (rightName, rightSql) <- baseTypes
  } {
    createQueryTest(s"$leftName + $rightName", s"SELECT $leftSql + $rightSql FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with NULL in either branch; the last base type (the string literal) is skipped.
  for ((name, sql) <- baseTypes.init) {
    createQueryTest(s"case when then $name else $nullVal end ",
      s"SELECT case when true then $sql else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $name end ",
      s"SELECT case when true then $nullVal else $sql end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val sqlText = "select cast(cast(key=0 as boolean) as boolean) from src"
    val physicalPlan = TestHive.sql(sqlText).queryExecution.sparkPlan
    val project = physicalPlan.collect { case p: ProjectExec => p }.head

    // The redundant casts must be gone entirely.
    project.transformAllExpressions {
      case cast: Cast =>
        fail(s"unexpected cast $cast")
        cast
    }

    // Exactly one equality comparison should remain.
    var equalityCount = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        equalityCount += 1
        equality
    }
    assert(equalityCount === 1)
  }
}
Example 4
Source File: HierarchyAnalysis.scala (from HANAVora-Extensions, Apache License 2.0; 5 votes)
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis.Catalog
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.Node
import org.apache.spark.sql.catalyst.expressions.EqualTo
import org.apache.spark.sql.catalyst.expressions.AttributeReference


  /**
   * Recursively checks `expr` and all of its children. Any [[NodePredicate]] whose two
   * operands are both attribute references must compare nodes from the same hierarchy.
   *
   * @throws AnalysisException when the two operands resolve to hierarchies with different
   *                           identifiers
   */
  private[this] def supportsExpression(expr: Expression, plan: LogicalPlan): Unit = {
    expr match {
      case predicate: NodePredicate =>
        (predicate.left, predicate.right) match {
          case (leftRef: AttributeReference, rightRef: AttributeReference) =>
            val leftHierarchy = getReferencedHierarchy(plan, leftRef.exprId)
            val rightHierarchy = getReferencedHierarchy(plan, rightRef.exprId)
            if (leftHierarchy.identifier != rightHierarchy.identifier) {
              throw new AnalysisException(MIXED_NODES_ERROR.format(predicate.symbol))
            }
          case _ => // operands are not both attributes: nothing to check
        }
      case _ => // not a node predicate: nothing to check at this level
    }
    expr.children.foreach(child => supportsExpression(child, plan))
  }

  /**
   * Returns the first [[Hierarchy]] in `plan` (in `collectFirst` order) whose node attribute
   * has expression ID `exprId`.
   *
   * @throws NoSuchElementException if `plan` contains no such hierarchy
   */
  private def getReferencedHierarchy(plan: LogicalPlan, exprId: ExprId): Hierarchy = {
    val matching = plan.collectFirst {
      case hierarchy @ Hierarchy(_, nodeAttr) if nodeAttr.exprId.equals(exprId) => hierarchy
    }
    matching.get
  }
} 
Example 5
Source File: ResolveHierarchySuite.scala (from HANAVora-Extensions, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Attribute, EqualTo}
import org.apache.spark.sql.catalyst.plans.logical.{AdjacencyListHierarchySpec, Hierarchy}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

/**
 * Checks that hierarchy resolution leaves an adjacency-list parenthood expression without
 * conflicting expression IDs or qualifiers.
 */
class ResolveHierarchySuite extends FunSuite with MockitoSugar {

  // Mocked base relation exposing two integer columns, `id` and `parent`.
  val br1 = new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]
    override def schema: StructType = StructType(
      StructField("id", IntegerType) :: StructField("parent", IntegerType) :: Nil)
  }

  val lr1 = LogicalRelation(br1)
  val idAtt = lr1.output.find(attr => attr.name == "id").get
  val parentAtt = lr1.output.find(attr => attr.name == "parent").get

  test("Check parenthood expression has no conflicting expression IDs and qualifiers") {
    val source = SimpleAnalyzer.execute(lr1.select('id, 'parent).subquery('u))
    assert(source.resolved)

    // Parenthood compares `u.id` with `v.id`; both sides start out unresolved.
    val parenthood =
      UnresolvedAttribute("u" :: "id" :: Nil) === UnresolvedAttribute("v" :: "id" :: Nil)
    val hierarchy = Hierarchy(
      AdjacencyListHierarchySpec(source, "v", parenthood, Some('id.isNull), Nil),
      'node)

    val resolveHierarchy = ResolveHierarchy(SimpleAnalyzer)
    val resolveReferences = ResolveReferencesWithHierarchies(SimpleAnalyzer)

    // One pass: hierarchy resolution, hierarchy-aware reference resolution, then the
    // analyzer's standard reference resolution.
    def analyzeOnce(current: Hierarchy): Hierarchy =
      SimpleAnalyzer.ResolveReferences(resolveReferences(resolveHierarchy(current)))
        .asInstanceOf[Hierarchy]

    // Apply the pass eleven times (indices 0 to 10).
    val resolvedHierarchy =
      (0 to 10).foldLeft(hierarchy: Hierarchy)((current, _) => analyzeOnce(current))

    assert(resolvedHierarchy.node.resolved)
    val resolvedSpec = resolvedHierarchy.spec.asInstanceOf[AdjacencyListHierarchySpec]
    assert(resolvedSpec.parenthoodExp.resolved)
    assert(resolvedSpec.startWhere.forall(_.resolved))
    assert(resolvedHierarchy.childrenResolved)
    assert(resolvedHierarchy.resolved)

    val parenthoodExpression = resolvedSpec.parenthoodExp.asInstanceOf[EqualTo]

    val leftAttr = parenthoodExpression.left.asInstanceOf[Attribute]
    assertResult("u" :: Nil)(leftAttr.qualifiers)
    val rightAttr = parenthoodExpression.right.asInstanceOf[Attribute]
    assertResult("v" :: Nil)(rightAttr.qualifiers)

    // The resolved `v.id` must not reuse the expression ID of the source's `id` column.
    val sourceId = source.output.find(_.name == "id").get
    assert(rightAttr.exprId != sourceId.exprId)
  }

}
Example 6
Source File: HiveTypeCoercionSuite.scala (from sparkoscope, Apache License 2.0; 5 votes)
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive


/**
 * Type-coercion checks for Hive-compatible SQL: addition over every pair of base literal types,
 * CASE WHEN mixing a typed branch with NULL, and removal of redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // (test-name fragment, SQL literal) pairs. Only the double literal differs: the query spells
  // it as an explicit CAST.
  val baseTypes = Seq(
    ("1", "1"),
    ("1.0", "CAST(1.0 AS DOUBLE)"),
    ("1L", "1L"),
    ("1S", "1S"),
    ("1Y", "1Y"),
    ("'1'", "'1'"))

  // Addition over the full cross product of base types.
  for {
    (leftName, leftSql) <- baseTypes
    (rightName, rightSql) <- baseTypes
  } {
    createQueryTest(s"$leftName + $rightName", s"SELECT $leftSql + $rightSql FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with NULL in either branch; the last base type (the string literal) is skipped.
  for ((name, sql) <- baseTypes.init) {
    createQueryTest(s"case when then $name else $nullVal end ",
      s"SELECT case when true then $sql else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $name end ",
      s"SELECT case when true then $nullVal else $sql end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val sqlText = "select cast(cast(key=0 as boolean) as boolean) from src"
    val physicalPlan = TestHive.sql(sqlText).queryExecution.sparkPlan
    val project = physicalPlan.collect { case p: ProjectExec => p }.head

    // The redundant casts must be gone entirely.
    project.transformAllExpressions {
      case cast: Cast =>
        fail(s"unexpected cast $cast")
        cast
    }

    // Exactly one equality comparison should remain.
    var equalityCount = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        equalityCount += 1
        equality
    }
    assert(equalityCount === 1)
  }
}
Example 7
Source File: HiveClientSuite.scala (from sparkoscope, Apache License 2.0; 5 votes)
package org.apache.spark.sql.hive.client

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.conf.HiveConf

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.types.IntegerType

/** Checks partition listing through the Hive client when metastore direct SQL is disabled. */
class HiveClientSuite extends SparkFunSuite {
  private val clientBuilder = new HiveClientBuilder

  // Hive configuration key that toggles the metastore's direct-SQL path.
  private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname

  test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
    val testPartitionCount = 5

    // Empty storage descriptor shared by every partition.
    val storageFormat = CatalogStorageFormat(
      locationUri = None, inputFormat = None, outputFormat = None,
      serde = None, compressed = false, properties = Map.empty)

    val client = {
      val conf = new Configuration()
      conf.setBoolean(tryDirectSqlKey, false)
      clientBuilder.buildClient(HiveUtils.hiveExecutionVersion, conf)
    }
    client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (part INT)")

    // One partition per value of `part` in 1..testPartitionCount.
    val partitions = (1 to testPartitionCount).map { partValue =>
      CatalogTablePartition(Map("part" -> partValue.toString), storageFormat)
    }
    client.createPartitions("default", "test", partitions, ignoreIfExists = false)

    // The filter selects part = 3, yet the test expects every partition back.
    val table = client.getTable("default", "test")
    val partFilter = EqualTo(AttributeReference("part", IntegerType)(), Literal(3))
    val filteredPartitions = client.getPartitionsByFilter(table, Seq(partFilter))

    assert(filteredPartitions.size == testPartitionCount)
  }
}
Example 8
Source File: HiveTypeCoercionSuite.scala (from multi-tenancy-spark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive


/**
 * Type-coercion checks for Hive-compatible SQL: addition over every pair of base literal types,
 * CASE WHEN mixing a typed branch with NULL, and removal of redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // (test-name fragment, SQL literal) pairs. Only the double literal differs: the query spells
  // it as an explicit CAST.
  val baseTypes = Seq(
    ("1", "1"),
    ("1.0", "CAST(1.0 AS DOUBLE)"),
    ("1L", "1L"),
    ("1S", "1S"),
    ("1Y", "1Y"),
    ("'1'", "'1'"))

  // Addition over the full cross product of base types.
  for {
    (leftName, leftSql) <- baseTypes
    (rightName, rightSql) <- baseTypes
  } {
    createQueryTest(s"$leftName + $rightName", s"SELECT $leftSql + $rightSql FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with NULL in either branch; the last base type (the string literal) is skipped.
  for ((name, sql) <- baseTypes.init) {
    createQueryTest(s"case when then $name else $nullVal end ",
      s"SELECT case when true then $sql else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $name end ",
      s"SELECT case when true then $nullVal else $sql end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val sqlText = "select cast(cast(key=0 as boolean) as boolean) from src"
    val physicalPlan = TestHive.sql(sqlText).queryExecution.sparkPlan
    val project = physicalPlan.collect { case p: ProjectExec => p }.head

    // The redundant casts must be gone entirely.
    project.transformAllExpressions {
      case cast: Cast =>
        fail(s"unexpected cast $cast")
        cast
    }

    // Exactly one equality comparison should remain.
    var equalityCount = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        equalityCount += 1
        equality
    }
    assert(equalityCount === 1)
  }
}
Example 9
Source File: HiveTypeCoercionSuite.scala (from iolap, Apache License 2.0; 5 votes)
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive


/**
 * Type-coercion checks for Hive-compatible SQL: addition over every pair of base literal types,
 * CASE WHEN mixing a typed branch with NULL, removal of redundant boolean casts, and COALESCE
 * over incompatible types.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // SQL literals of the base types; each one also serves as its own test-name fragment.
  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")

  // Addition over the full cross product of base types.
  for {
    left <- baseTypes
    right <- baseTypes
  } {
    createQueryTest(s"$left + $right", s"SELECT $left + $right FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with NULL in either branch; the last base type (the string literal) is skipped.
  for (literal <- baseTypes.init) {
    createQueryTest(s"case when then $literal else $nullVal end ",
      s"SELECT case when true then $literal else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $literal end ",
      s"SELECT case when true then $nullVal else $literal end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val sqlText = "select cast(cast(key=0 as boolean) as boolean) from src"
    val physicalPlan = TestHive.sql(sqlText).queryExecution.executedPlan
    val project = physicalPlan.collect { case p: Project => p }.head

    // The redundant casts must be gone entirely.
    project.transformAllExpressions {
      case cast: Cast =>
        fail(s"unexpected cast $cast")
        cast
    }

    // Exactly one equality comparison should remain.
    var equalityCount = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        equalityCount += 1
        equality
    }
    assert(equalityCount === 1)
  }

  test("COALESCE with different types") {
    val mixedCoalesce = """SELECT COALESCE(1, true, "abc") FROM src limit 1"""
    intercept[RuntimeException](TestHive.sql(mixedCoalesce).collect())
  }
}
Example 10
Source File: HiveTypeCoercionSuite.scala (from spark1.52, Apache License 2.0; 5 votes)
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive


/**
 * Type-coercion checks for Hive-compatible SQL: addition over every pair of base literal types,
 * CASE WHEN mixing a typed branch with NULL, and removal of redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // SQL literals of the base types; each one also serves as its own test-name fragment.
  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")

  // Addition over the full cross product of base types.
  for {
    left <- baseTypes
    right <- baseTypes
  } {
    createQueryTest(s"$left + $right", s"SELECT $left + $right FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with NULL in either branch; the last base type (the string literal) is skipped.
  for (literal <- baseTypes.init) {
    createQueryTest(s"case when then $literal else $nullVal end ",
      s"SELECT case when true then $literal else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $literal end ",
      s"SELECT case when true then $nullVal else $literal end FROM src limit 1")
  }

  // A boolean cast applied to a value that is already boolean should be removed.
  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val sqlText = "select cast(cast(key=0 as boolean) as boolean) from src"
    val physicalPlan = TestHive.sql(sqlText).queryExecution.executedPlan
    val project = physicalPlan.collect { case p: Project => p }.head

    // No cast expression may remain in the plan.
    project.transformAllExpressions {
      case cast: Cast =>
        fail(s"unexpected cast $cast")
        cast
    }

    // Exactly one equality check should remain.
    var equalityCount = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        equalityCount += 1
        equality
    }
    assert(equalityCount === 1)
  }
}
Example 11
Source File: HiveTypeCoercionSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.hive.test.TestHive


/**
 * Type-coercion checks for Hive-compatible SQL: addition over every pair of base literal types,
 * CASE WHEN mixing a typed branch with NULL, and removal of redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // (test-name fragment, SQL literal) pairs. Only the double literal differs: the query spells
  // it as an explicit CAST.
  val baseTypes = Seq(
    ("1", "1"),
    ("1.0", "CAST(1.0 AS DOUBLE)"),
    ("1L", "1L"),
    ("1S", "1S"),
    ("1Y", "1Y"),
    ("'1'", "'1'"))

  // Addition over the full cross product of base types.
  for {
    (leftName, leftSql) <- baseTypes
    (rightName, rightSql) <- baseTypes
  } {
    createQueryTest(s"$leftName + $rightName", s"SELECT $leftSql + $rightSql FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with NULL in either branch; the last base type (the string literal) is skipped.
  for ((name, sql) <- baseTypes.init) {
    createQueryTest(s"case when then $name else $nullVal end ",
      s"SELECT case when true then $sql else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $name end ",
      s"SELECT case when true then $nullVal else $sql end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val sqlText = "select cast(cast(key=0 as boolean) as boolean) from src"
    val physicalPlan = TestHive.sql(sqlText).queryExecution.sparkPlan
    val project = physicalPlan.collect { case p: ProjectExec => p }.head

    // The redundant casts must be gone entirely.
    project.transformAllExpressions {
      case cast: Cast =>
        fail(s"unexpected cast $cast")
        cast
    }

    // Exactly one equality comparison should remain.
    var equalityCount = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        equalityCount += 1
        equality
    }
    assert(equalityCount === 1)
  }
}
Example 12
Source File: HiveTypeCoercionSuite.scala (from BigDatalog, Apache License 2.0; 5 votes)
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo}
import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.hive.test.TestHive


/**
 * Type-coercion checks for Hive-compatible SQL: addition over every pair of base literal types,
 * CASE WHEN mixing a typed branch with NULL, and removal of redundant boolean casts.
 */
class HiveTypeCoercionSuite extends HiveComparisonTest {
  // SQL literals of the base types; each one also serves as its own test-name fragment.
  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")

  // Addition over the full cross product of base types.
  for {
    left <- baseTypes
    right <- baseTypes
  } {
    createQueryTest(s"$left + $right", s"SELECT $left + $right FROM src LIMIT 1")
  }

  val nullVal = "null"

  // CASE WHEN with NULL in either branch; the last base type (the string literal) is skipped.
  for (literal <- baseTypes.init) {
    createQueryTest(s"case when then $literal else $nullVal end ",
      s"SELECT case when true then $literal else $nullVal end FROM src limit 1")
    createQueryTest(s"case when then $nullVal else $literal end ",
      s"SELECT case when true then $nullVal else $literal end FROM src limit 1")
  }

  test("[SPARK-2210] boolean cast on boolean value should be removed") {
    val sqlText = "select cast(cast(key=0 as boolean) as boolean) from src"
    val physicalPlan = TestHive.sql(sqlText).queryExecution.executedPlan
    val project = physicalPlan.collect { case p: Project => p }.head

    // The redundant casts must be gone entirely.
    project.transformAllExpressions {
      case cast: Cast =>
        fail(s"unexpected cast $cast")
        cast
    }

    // Exactly one equality comparison should remain.
    var equalityCount = 0
    project.transformAllExpressions {
      case equality: EqualTo =>
        equalityCount += 1
        equality
    }
    assert(equalityCount === 1)
  }
}
Example 13
Source File: ParameterBinderSuite.scala (from spark-sql-server, Apache License 2.0; 5 votes)
package org.apache.spark.sql.server.service.postgresql.protocol.v3

import java.sql.SQLException

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.server.catalyst.expressions.ParameterPlaceHolder
import org.apache.spark.sql.server.service.ParamBinder
import org.apache.spark.sql.types._

/**
 * Tests for `ParamBinder.bind`. Binding replaces each `ParameterPlaceHolder(n)` with the
 * literal mapped to `n`, and reports every placeholder left without a value.
 */
class ParameterBinderSuite extends PlanTest {

  test("bind parameters") {
    val colA = 'a.int
    val colB = 'b.int
    val relation = LocalRelation(colA, colB)

    // A single placeholder with index 1.
    val bound1 = Literal(18, IntegerType)
    val plan1 = Filter(EqualTo(colA, ParameterPlaceHolder(1)), relation)
    comparePlans(Filter(EqualTo(colA, bound1), relation), ParamBinder.bind(plan1, Map(1 -> bound1)))

    // Placeholder indices need not be small or contiguous.
    val bound300 = Literal(42, IntegerType)
    val plan2 = Filter(EqualTo(colA, ParameterPlaceHolder(300)), relation)
    comparePlans(
      Filter(EqualTo(colA, bound300), relation),
      ParamBinder.bind(plan2, Map(300 -> bound300)))

    // Two placeholders inside a conjunction.
    val boundA = Literal(-1, IntegerType)
    val boundB = Literal(48, IntegerType)
    val plan3 = Filter(
      And(EqualTo(colA, ParameterPlaceHolder(1)), EqualTo(colB, ParameterPlaceHolder(2))),
      relation)
    val expected3 = Filter(And(EqualTo(colA, boundA), EqualTo(colB, boundB)), relation)
    comparePlans(expected3, ParamBinder.bind(plan3, Map(1 -> boundA, 2 -> boundB)))

    // Binding with no values must fail and list every unresolved placeholder.
    def unresolvedError(plan: LogicalPlan): String =
      intercept[SQLException] { ParamBinder.bind(plan, Map.empty) }.getMessage

    assert(unresolvedError(plan1) == "Unresolved parameters found: $1")
    assert(unresolvedError(plan2) == "Unresolved parameters found: $300")
    assert(unresolvedError(plan3) == "Unresolved parameters found: $1, $2")
  }
}