org.apache.spark.sql.catalyst.analysis.TypeCheckResult Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.analysis.TypeCheckResult. You can go to the original project or source file by following the links above each example.
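Before the examples, a quick orientation: a Catalyst expression reports the outcome of checkInputDataTypes() as either TypeCheckResult.TypeCheckSuccess or a TypeCheckResult.TypeCheckFailure carrying an error message, and the analyzer turns a failure into an AnalysisException. The minimal sketch below, written against the Spark 2.x catalyst API, is invented for illustration (HypotheticalNegate and its message are not from any project quoted here):

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{DataType, NumericType}

// Invented expression (illustration only): arithmetic negation of a numeric child.
case class HypotheticalNegate(child: Expression) extends UnaryExpression {
  override def dataType: DataType = child.dataType

  // The analyzer calls this; a TypeCheckFailure becomes an AnalysisException.
  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case _: NumericType => TypeCheckResult.TypeCheckSuccess
    case other =>
      TypeCheckResult.TypeCheckFailure(s"hypothetical_negate requires numeric types, not $other")
  }

  private lazy val numeric = TypeUtils.getNumeric(dataType)

  protected override def nullSafeEval(input: Any): Any = numeric.negate(input)

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, c => s"(-($c))")
}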
Example 1
Source File: SortOrder.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator

abstract sealed class SortDirection
case object Ascending extends SortDirection
case object Descending extends SortDirection


case class SortPrefix(child: SortOrder) extends UnaryExpression {

  override def eval(input: InternalRow): Any = throw new UnsupportedOperationException

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val childCode = child.child.gen(ctx)
    val input = childCode.primitive
    val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName
    val DoublePrefixCmp = classOf[DoublePrefixComparator].getName

    val (nullValue: Long, prefixCode: String) = child.child.dataType match {
      case BooleanType =>
        (Long.MinValue, s"$input ? 1L : 0L")
      case _: IntegralType =>
        (Long.MinValue, s"(long) $input")
      case DateType | TimestampType =>
        (Long.MinValue, s"(long) $input")
      case FloatType | DoubleType =>
        (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
          s"$DoublePrefixCmp.computePrefix((double)$input)")
      case StringType => (0L, s"$input.getPrefix()")
      case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)")
      case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
        val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
          s"$input.toUnscaledLong()"
        } else {
          // reduce the scale to fit in a long
          val p = Decimal.MAX_LONG_DIGITS
          val s = p - (dt.precision - dt.scale)
          s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
        }
        (Long.MinValue, prefix)
      case dt: DecimalType =>
        (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
          s"$DoublePrefixCmp.computePrefix($input.toDouble())")
      case _ => (0L, "0L")
    }

    childCode.code +
    s"""
      |long ${ev.primitive} = ${nullValue}L;
      |boolean ${ev.isNull} = false;
      |if (!${childCode.isNull}) {
      |  ${ev.primitive} = $prefixCode;
      |}
    """.stripMargin
  }

  override def dataType: DataType = LongType
} 
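Note that the SortOrder class itself, which is what actually uses the TypeCheckResult import in this file, was elided from the listing above. Reconstructed roughly from the upstream Spark source (a sketch, not verbatim; RowOrdering lives in the same package, so no extra import is needed):

case class SortOrder(child: Expression, direction: SortDirection)
  extends UnaryExpression with Unevaluable {

  // Sort order is not foldable because we don't want it to be constant-folded.
  override def foldable: Boolean = false

  override def dataType: DataType = child.dataType
  override def nullable: Boolean = child.nullable

  // This is where TypeCheckResult comes in: only orderable types may be sorted.
  override def checkInputDataTypes(): TypeCheckResult = {
    if (RowOrdering.isOrderable(dataType)) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}")
    }
  }

  override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"

  def isAscending: Boolean = direction == Ascending
}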
Example 2
Source File: MeanSubstitute.scala    From glow   with Apache License 2.0
package io.projectglow.sql.expressions

import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, NumericType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.sql.dsl._
import io.projectglow.sql.util.RewriteAfterResolution


case class MeanSubstitute(array: Expression, missingValue: Expression)
    extends RewriteAfterResolution {

  override def children: Seq[Expression] = Seq(array, missingValue)

  def this(array: Expression) = {
    this(array, Literal(-1))
  }

  private lazy val arrayElementType = array.dataType.asInstanceOf[ArrayType].elementType

  // A value is considered missing if it is NaN, null or equal to the missing value parameter
  def isMissing(arrayElement: Expression): Predicate =
    IsNaN(arrayElement) || IsNull(arrayElement) || arrayElement === missingValue

  def createNamedStruct(sumValue: Expression, countValue: Expression): Expression = {
    val sumName = Literal(UTF8String.fromString("sum"), StringType)
    val countName = Literal(UTF8String.fromString("count"), StringType)
    namedStruct(sumName, sumValue, countName, countValue)
  }

  // Update sum and count with array element if not missing
  def updateSumAndCountConditionally(
      stateStruct: Expression,
      arrayElement: Expression): Expression = {
    If(
      isMissing(arrayElement),
      // If value is missing, do not update sum and count
      stateStruct,
      // If value is not missing, add to sum and increment count
      createNamedStruct(
        stateStruct.getField("sum") + arrayElement,
        stateStruct.getField("count") + 1)
    )
  }

  // Calculate mean for imputation
  def calculateMean(stateStruct: Expression): Expression = {
    If(
      stateStruct.getField("count") > 0,
      // If non-missing values were found, calculate the average
      stateStruct.getField("sum") / stateStruct.getField("count"),
      // If all values were missing, substitute with missing value
      missingValue
    )
  }

  lazy val arrayMean: Expression = {
    // Sum and count of non-missing values
    array.aggregate(
      createNamedStruct(Literal(0d), Literal(0L)),
      updateSumAndCountConditionally,
      calculateMean
    )
  }

  def substituteWithMean(arrayElement: Expression): Expression = {
    If(isMissing(arrayElement), arrayMean, arrayElement)
  }

  override def rewrite: Expression = {
    if (!array.dataType.isInstanceOf[ArrayType] || !arrayElementType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Can only perform mean substitution on numeric array; provided type is ${array.dataType}.")
    }

    if (!missingValue.dataType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Missing value must be of numeric type; provided type is ${missingValue.dataType}.")
    }

    // Replace missing values with the provided strategy
    array.arrayTransform(substituteWithMean(_))
  }
} 
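A hedged usage sketch: Glow registers this rewrite under the SQL function name mean_substitute, so after registration it can be invoked from SQL (Glow.register is the real Glow entry point; the table and column names below are invented):

import io.projectglow.Glow

val sess = Glow.register(spark)
sess.sql("SELECT mean_substitute(values, -1) AS imputed FROM genotypes_table")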
Example 3
Source File: TypeUtils.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._


object TypeUtils {
  def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (dt.isInstanceOf[NumericType] || dt == NullType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt")
    }
  }

  def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (RowOrdering.isOrderable(dt)) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt")
    }
  }

  def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
    if (types.distinct.size > 1) {
      TypeCheckResult.TypeCheckFailure(
        s"input to $caller should all be the same type, but it's " +
          types.map(_.simpleString).mkString("[", ", ", "]"))
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  def getNumeric(t: DataType): Numeric[Any] =
    t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]

  def getInterpretedOrdering(t: DataType): Ordering[Any] = {
    t match {
      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
      case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
    }
  }

  def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
    for (i <- 0 until x.length; if i < y.length) {
      val res = x(i).compareTo(y(i))
      if (res != 0) return res
    }
    x.length - y.length
  }
} 
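These helpers are typically delegated to from an expression's checkInputDataTypes(), as the Average example later on this page does. A small self-contained sketch of calling them directly (the "function coalesce" caller string is illustrative):

import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType}

object TypeUtilsDemo extends App {
  println(TypeUtils.checkForNumericExpr(DoubleType, "function average"))  // TypeCheckSuccess
  println(TypeUtils.checkForNumericExpr(StringType, "function average"))  // TypeCheckFailure(...)
  println(TypeUtils.checkForSameTypeInputExpr(
    Seq(IntegerType, DoubleType), "function coalesce"))                   // TypeCheckFailure(...)
}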
Example 4
Source File: SortOrder.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator

abstract sealed class SortDirection
case object Ascending extends SortDirection
case object Descending extends SortDirection


case class SortPrefix(child: SortOrder) extends UnaryExpression {

  override def eval(input: InternalRow): Any = throw new UnsupportedOperationException

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val childCode = child.child.gen(ctx)
    val input = childCode.value
    val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName
    val DoublePrefixCmp = classOf[DoublePrefixComparator].getName

    val (nullValue: Long, prefixCode: String) = child.child.dataType match {
      case BooleanType =>
        (Long.MinValue, s"$input ? 1L : 0L")
      case _: IntegralType =>
        (Long.MinValue, s"(long) $input")
      case DateType | TimestampType =>
        (Long.MinValue, s"(long) $input")
      case FloatType | DoubleType =>
        (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
          s"$DoublePrefixCmp.computePrefix((double)$input)")
      case StringType => (0L, s"$input.getPrefix()")
      case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)")
      case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
        val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
          s"$input.toUnscaledLong()"
        } else {
          // reduce the scale to fit in a long
          val p = Decimal.MAX_LONG_DIGITS
          val s = p - (dt.precision - dt.scale)
          s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
        }
        (Long.MinValue, prefix)
      case dt: DecimalType =>
        (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
          s"$DoublePrefixCmp.computePrefix($input.toDouble())")
      case _ => (0L, "0L")
    }

    childCode.code +
    s"""
      |long ${ev.value} = ${nullValue}L;
      |boolean ${ev.isNull} = false;
      |if (!${childCode.isNull}) {
      |  ${ev.value} = $prefixCode;
      |}
    """.stripMargin
  }

  override def dataType: DataType = LongType
} 
Example 5
Source File: Average.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

case class Average(child: Expression) extends DeclarativeAggregate {

  override def prettyName: String = "avg"

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = resultType

  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForNumericExpr(child.dataType, "function average")

  private lazy val resultType = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      DecimalType.bounded(p + 4, s + 4)
    case _ => DoubleType
  }

  private lazy val sumDataType = child.dataType match {
    case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
    case _ => DoubleType
  }

  private lazy val sum = AttributeReference("sum", sumDataType)()
  private lazy val count = AttributeReference("count", LongType)()

  override lazy val aggBufferAttributes = sum :: count :: Nil

  override lazy val initialValues = Seq(
    /* sum = */ Cast(Literal(0), sumDataType),
    /* count = */ Literal(0L)
  )

  override lazy val updateExpressions = Seq(
    /* sum = */ Add(
      sum,
      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
    /* count = */ If(IsNull(child), count, count + 1L)
  )

  override lazy val mergeExpressions = Seq(
    /* sum = */ sum.left + sum.right,
    /* count = */ count.left + count.right
  )

  // If all input are nulls, count will be 0 and we will get null after the division.
  override lazy val evaluateExpression = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      // increase the precision and scale to prevent precision loss
      val dt = DecimalType.bounded(p + 14, s + 4)
      Cast(Cast(sum, dt) / Cast(count, dt), resultType)
    case _ =>
      Cast(sum, resultType) / Cast(count, resultType)
  }
} 
Example 6
Source File: monotonicaggregates.scala    From BigDatalog   with Apache License 2.0
package edu.ucla.cs.wis.bigdatalog.spark.execution.aggregates

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, Greatest, Least, Literal, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}

abstract class MonotonicAggregateFunction extends DeclarativeAggregate with Serializable {}

case class MMax(child: Expression) extends MonotonicAggregateFunction {

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = child.dataType

  // Expected input data type.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function mmax")

  private lazy val mmax = AttributeReference("mmax", child.dataType)()

  override lazy val aggBufferAttributes: Seq[AttributeReference] = mmax :: Nil

  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create(null, child.dataType)
  )

  override lazy val updateExpressions: Seq[Expression] = Seq(
    Greatest(Seq(mmax, child))
  )

  override lazy val mergeExpressions: Seq[Expression] = Seq(
    Greatest(Seq(mmax.left, mmax.right))
  )

  override lazy val evaluateExpression: AttributeReference = mmax
}

case class MonotonicAggregateExpression(aggregateFunction: MonotonicAggregateFunction,
                                        mode: AggregateMode,
                                        isDistinct: Boolean)
  extends Expression
    with Unevaluable {

  override def children: Seq[Expression] = aggregateFunction :: Nil
  override def dataType: DataType = aggregateFunction.dataType
  override def foldable: Boolean = false
  override def nullable: Boolean = aggregateFunction.nullable

  override def references: AttributeSet = {
    val childReferences = mode match {
      case Partial | Complete => aggregateFunction.references.toSeq
      case PartialMerge | Final => aggregateFunction.aggBufferAttributes
    }

    AttributeSet(childReferences)
  }

  override def prettyString: String = aggregateFunction.prettyString

  override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)"
} 
Example 7
Source File: JsonGroupArray.scala    From mimir   with Apache License 2.0
package mimir.exec.spark.udf

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{ DataType, StringType }
import org.apache.spark.sql.catalyst.expressions.{ 
  AttributeReference, 
  If, 
  StartsWith, 
  Literal, 
  IsNull, 
  Concat, 
  Substring
}

case class JsonGroupArray(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate {
  override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil
  override def nullable: Boolean = false
  // Return data type.
  override def dataType: DataType = StringType
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function json_group_array")
  private lazy val json_group_array = AttributeReference("json_group_array", StringType)()
  override lazy val aggBufferAttributes: Seq[AttributeReference] = json_group_array :: Nil
  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create("", StringType)
  )
  override lazy val updateExpressions: Seq[ org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    If(IsNull(child),
      Concat(Seq(json_group_array, Literal(","), Literal("null"))),
      Concat(Seq(json_group_array, Literal(","), org.apache.spark.sql.catalyst.expressions.Cast(child,StringType,None)))) 
  )
  override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = {
    Seq(
      Concat(Seq(json_group_array.left, json_group_array.right))
    )
  }
  override lazy val evaluateExpression = Concat(Seq(
    Literal("["),
    If(StartsWith(json_group_array, Literal(",")),
      Substring(json_group_array, Literal(2), Literal(Integer.MAX_VALUE)),
      json_group_array),
    Literal("]")))
} 
Example 8
Source File: GroupBitwiseOr.scala    From mimir   with Apache License 2.0
package mimir.exec.spark.udf

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{ DataType, LongType }
import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, BitwiseOr }

case class GroupBitwiseOr(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate {
  override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil
  override def nullable: Boolean = false
  // Return data type.
  override def dataType: DataType = LongType
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function group_bitwise_or")
  private lazy val group_bitwise_or = AttributeReference("group_bitwise_or", LongType)()
  override lazy val aggBufferAttributes: Seq[AttributeReference] = group_bitwise_or :: Nil
  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create(0L, LongType)
  )
  override lazy val updateExpressions: Seq[ org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    BitwiseOr(group_bitwise_or, child)
  )
  override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = {
    Seq(
      BitwiseOr(group_bitwise_or.left, group_bitwise_or.right)
    )
  }
  override lazy val evaluateExpression: AttributeReference = group_bitwise_or
} 
Example 9
Source File: GroupBitwiseAnd.scala    From mimir   with Apache License 2.0
package mimir.exec.spark.udf

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{ DataType, LongType }
import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, BitwiseAnd }

case class GroupBitwiseAnd(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate {
  override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil
  override def nullable: Boolean = false
  // Return data type.
  override def dataType: DataType = LongType
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function group_bitwise_and")
  private lazy val group_bitwise_and = AttributeReference("group_bitwise_and", LongType)()
  override lazy val aggBufferAttributes: Seq[AttributeReference] = group_bitwise_and :: Nil
  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create(0xffffffffffffffffL, LongType)
  )
  override lazy val updateExpressions: Seq[ org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    BitwiseAnd(group_bitwise_and, child)
  )
  override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = {
    Seq(
      BitwiseAnd(group_bitwise_and.left, group_bitwise_and.right)
    )
  }
  override lazy val evaluateExpression: AttributeReference = group_bitwise_and
} 
Example 10
Source File: GroupAnd.scala    From mimir   with Apache License 2.0
package mimir.exec.spark.udf

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{ DataType, BooleanType }
import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, And }

case class GroupAnd(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate {
  override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil
  override def nullable: Boolean = false
  // Return data type.
  override def dataType: DataType = BooleanType
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function group_and")
  private lazy val group_and = AttributeReference("group_and", BooleanType)()
  override lazy val aggBufferAttributes: Seq[AttributeReference] = group_and :: Nil
  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create(true, BooleanType)
  )
  override lazy val updateExpressions: Seq[ org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    And(group_and, child)
  )
  override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = {
    Seq(
      And(group_and.left, group_and.right)
    )
  }
  override lazy val evaluateExpression: AttributeReference = group_and
} 
Example 11
Source File: GroupOr.scala    From mimir   with Apache License 2.0
package mimir.exec.spark.udf

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{ DataType, BooleanType }
import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, Or }


case class GroupOr(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate {
  override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil
  override def nullable: Boolean = false
  // Return data type.
  override def dataType: DataType = BooleanType
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function group_or")
  private lazy val group_or = AttributeReference("group_or", BooleanType)()
  override lazy val aggBufferAttributes: Seq[AttributeReference] = group_or :: Nil
  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create(false, BooleanType)
  )
  override lazy val updateExpressions: Seq[ org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    Or(group_or, child)
  )
  override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = {
    Seq(
      Or(group_or.left, group_or.right)
    )
  }
  override lazy val evaluateExpression: AttributeReference = group_or
} 
Example 12
Source File: TypeUtils.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._


object TypeUtils {
  def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (dt.isInstanceOf[NumericType] || dt == NullType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt")
    }
  }

  def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (RowOrdering.isOrderable(dt)) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt")
    }
  }

  def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
    if (types.size <= 1) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      val firstType = types.head
      types.foreach { t =>
        if (!t.sameType(firstType)) {
          return TypeCheckResult.TypeCheckFailure(
            s"input to $caller should all be the same type, but it's " +
              types.map(_.simpleString).mkString("[", ", ", "]"))
        }
      }
      TypeCheckResult.TypeCheckSuccess
    }
  }

  def getNumeric(t: DataType): Numeric[Any] =
    t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]

  def getInterpretedOrdering(t: DataType): Ordering[Any] = {
    t match {
      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
      case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
    }
  }

  def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
    for (i <- 0 until x.length; if i < y.length) {
      val v1 = x(i) & 0xff
      val v2 = y(i) & 0xff
      val res = v1 - v2
      if (res != 0) return res
    }
    x.length - y.length
  }
} 
Example 13
Source File: Average.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {

  override def prettyName: String = "avg"

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = resultType

  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForNumericExpr(child.dataType, "function average")

  private lazy val resultType = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      DecimalType.bounded(p + 4, s + 4)
    case _ => DoubleType
  }

  private lazy val sumDataType = child.dataType match {
    case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
    case _ => DoubleType
  }

  private lazy val sum = AttributeReference("sum", sumDataType)()
  private lazy val count = AttributeReference("count", LongType)()

  override lazy val aggBufferAttributes = sum :: count :: Nil

  override lazy val initialValues = Seq(
    /* sum = */ Cast(Literal(0), sumDataType),
    /* count = */ Literal(0L)
  )

  override lazy val updateExpressions = Seq(
    /* sum = */ Add(
      sum,
      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
    /* count = */ If(IsNull(child), count, count + 1L)
  )

  override lazy val mergeExpressions = Seq(
    /* sum = */ sum.left + sum.right,
    /* count = */ count.left + count.right
  )

  // If all input are nulls, count will be 0 and we will get null after the division.
  override lazy val evaluateExpression = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      // increase the precision and scale to prevent precision loss
      val dt = DecimalType.bounded(p + 14, s + 4)
      Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)),
        resultType)
    case _ =>
      Cast(sum, resultType) / Cast(count, resultType)
  }
} 
Example 14
Source File: StatefulApproxQuantile.scala    From deequ   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, Literal}
import org.apache.spark.sql.types._


private[sql] case class StatefulApproxQuantile(
    child: Expression,
    accuracyExpression: Expression,
    override val mutableAggBufferOffset: Int,
    override val inputAggBufferOffset: Int)
  extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes {

  def this(child: Expression, accuracyExpression: Expression) = {
    this(child, accuracyExpression, 0, 0)
  }

  def this(child: Expression) = {
    this(child, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
  }

  // Mark as lazy so that accuracyExpression is not evaluated during tree transformation.
  private lazy val accuracy: Double = accuracyExpression.eval().asInstanceOf[Double]

  override def inputTypes: Seq[AbstractDataType] = {
    Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
  }

  override def checkInputDataTypes(): TypeCheckResult = {
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) {
      defaultCheck
    } else if (!accuracyExpression.foldable) {
      TypeCheckFailure(s"The accuracy provided must be a constant literal")
    } else if (accuracy <= 0) {
      TypeCheckFailure(
        s"The accuracy provided must be a positive integer literal (current value = $accuracy)")
    } else {
      TypeCheckSuccess
    }
  }

  override def createAggregationBuffer(): PercentileDigest = {
    val relativeError = 1.0D / accuracy
    new PercentileDigest(relativeError)
  }

  override def update(buffer: PercentileDigest, inputRow: InternalRow): PercentileDigest = {
    val value = child.eval(inputRow)
    // Ignore empty rows, for example: percentile_approx(null)
    if (value != null) {
      buffer.add(value.asInstanceOf[Double])
    }
    buffer
  }

  override def merge(buffer: PercentileDigest, other: PercentileDigest): PercentileDigest = {
    buffer.merge(other)
    buffer
  }

  override def eval(buffer: PercentileDigest): Any = {
    // instead of evaluating the PercentileDigest quantile summary here,
    // serialize the digest and return it as byte array
    serialize(buffer)
  }

  override def withNewMutableAggBufferOffset(newOffset: Int): StatefulApproxQuantile =
    copy(mutableAggBufferOffset = newOffset)

  override def withNewInputAggBufferOffset(newOffset: Int): StatefulApproxQuantile =
    copy(inputAggBufferOffset = newOffset)

  override def children: Seq[Expression] = Seq(child, accuracyExpression)

  // Returns null for empty inputs
  override def nullable: Boolean = true

  override def dataType: DataType = BinaryType

  override def prettyName: String = "percentile_approx"

  override def serialize(digest: PercentileDigest): Array[Byte] = {
    ApproximatePercentile.serializer.serialize(digest)
  }

  override def deserialize(bytes: Array[Byte]): PercentileDigest = {
    ApproximatePercentile.serializer.deserialize(bytes)
  }
} 
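Because eval returns the serialized digest rather than a quantile, a consumer deserializes the binary column before querying it. A hedged sketch, runnable from code in the same org.apache.spark.sql packages as this file (bytesFromRow is a placeholder for the bytes read back from the aggregate result):

import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile

// Placeholder: the Array[Byte] read from the aggregate's BinaryType column.
val bytesFromRow: Array[Byte] = ???

val digest = ApproximatePercentile.serializer.deserialize(bytesFromRow)
val Array(median, p90) = digest.getPercentiles(Array(0.5, 0.9))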
Example 15
Source File: TypeUtils.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._


object TypeUtils {
  def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (dt.isInstanceOf[NumericType] || dt == NullType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt")
    }
  }

  def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (RowOrdering.isOrderable(dt)) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt")
    }
  }

  def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
    if (types.distinct.size > 1) {
      TypeCheckResult.TypeCheckFailure(
        s"input to $caller should all be the same type, but it's " +
          types.map(_.simpleString).mkString("[", ", ", "]"))
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  def getNumeric(t: DataType): Numeric[Any] =
    t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]

  def getInterpretedOrdering(t: DataType): Ordering[Any] = {
    t match {
      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
    }
  }

  def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
    for (i <- 0 until x.length; if i < y.length) {
      val res = x(i).compareTo(y(i))
      if (res != 0) return res
    }
    x.length - y.length
  }
} 
Example 16
Source File: Average.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

@ExpressionDescription(
  usage = "_FUNC_(x) - Returns the mean calculated from values of a group.")
case class Average(child: Expression) extends DeclarativeAggregate {

  override def prettyName: String = "avg"

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = resultType

  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForNumericExpr(child.dataType, "function average")

  private lazy val resultType = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      DecimalType.bounded(p + 4, s + 4)
    case _ => DoubleType
  }

  private lazy val sumDataType = child.dataType match {
    case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
    case _ => DoubleType
  }

  private lazy val sum = AttributeReference("sum", sumDataType)()
  private lazy val count = AttributeReference("count", LongType)()

  override lazy val aggBufferAttributes = sum :: count :: Nil

  override lazy val initialValues = Seq(
    /* sum = */ Cast(Literal(0), sumDataType),
    /* count = */ Literal(0L)
  )

  override lazy val updateExpressions = Seq(
    /* sum = */ Add(
      sum,
      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
    /* count = */ If(IsNull(child), count, count + 1L)
  )

  override lazy val mergeExpressions = Seq(
    /* sum = */ sum.left + sum.right,
    /* count = */ count.left + count.right
  )

  // If all input are nulls, count will be 0 and we will get null after the division.
  override lazy val evaluateExpression = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      // increase the precision and scale to prevent precision loss
      val dt = DecimalType.bounded(p + 14, s + 4)
      Cast(Cast(sum, dt) / Cast(count, dt), resultType)
    case _ =>
      Cast(sum, resultType) / Cast(count, resultType)
  }
} 
Example 17
Source File: TypeUtils.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._


object TypeUtils {
  def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (dt.isInstanceOf[NumericType] || dt == NullType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt")
    }
  }

  def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (RowOrdering.isOrderable(dt)) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt")
    }
  }

  def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
    if (types.size <= 1) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      val firstType = types.head
      types.foreach { t =>
        if (!t.sameType(firstType)) {
          return TypeCheckResult.TypeCheckFailure(
            s"input to $caller should all be the same type, but it's " +
              types.map(_.simpleString).mkString("[", ", ", "]"))
        }
      }
      TypeCheckResult.TypeCheckSuccess
    }
  }

  def getNumeric(t: DataType): Numeric[Any] =
    t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]

  def getInterpretedOrdering(t: DataType): Ordering[Any] = {
    t match {
      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
      case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
    }
  }

  def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
    for (i <- 0 until x.length; if i < y.length) {
      val res = x(i).compareTo(y(i))
      if (res != 0) return res
    }
    x.length - y.length
  }
} 
Example 18
Source File: ReferenceToExpressions.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.types.DataType


case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference to any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ =>
    }
    if (maxOrdinal > children.length) {
      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
        s"there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childrenGen = children.map(_.genCode(ctx))
    val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
      case (childGen, child) =>
        // SPARK-18125: The children vars are local variables. If the result expression uses
        // splitExpression, those variables cannot be accessed so compilation fails.
        // To fix it, we use class variables to hold those local variables.
        val classChildVarName = ctx.freshName("classChildVar")
        val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
        ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
        ctx.addMutableState("boolean", classChildVarIsNull, "")

        val classChildVar =
          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)

        val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
          s"${classChildVar.isNull} = ${childGen.isNull};"

        (classChildVar, initCode)
    }.unzip

    val resultGen = result.transform {
      case b: BoundReference => classChildrenVars(b.ordinal)
    }.genCode(ctx)

    ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
      resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
  }
} 
Example 19
Source File: Average.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
case class Average(child: Expression) extends DeclarativeAggregate {

  override def prettyName: String = "avg"

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = resultType

  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForNumericExpr(child.dataType, "function average")

  private lazy val resultType = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      DecimalType.bounded(p + 4, s + 4)
    case _ => DoubleType
  }

  private lazy val sumDataType = child.dataType match {
    case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
    case _ => DoubleType
  }

  private lazy val sum = AttributeReference("sum", sumDataType)()
  private lazy val count = AttributeReference("count", LongType)()

  override lazy val aggBufferAttributes = sum :: count :: Nil

  override lazy val initialValues = Seq(
    /* sum = */ Cast(Literal(0), sumDataType),
    /* count = */ Literal(0L)
  )

  override lazy val updateExpressions = Seq(
    /* sum = */ Add(
      sum,
      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
    /* count = */ If(IsNull(child), count, count + 1L)
  )

  override lazy val mergeExpressions = Seq(
    /* sum = */ sum.left + sum.right,
    /* count = */ count.left + count.right
  )

  // If all input are nulls, count will be 0 and we will get null after the division.
  override lazy val evaluateExpression = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      // increase the precision and scale to prevent precision loss
      val dt = DecimalType.bounded(p + 14, s + 4)
      Cast(Cast(sum, dt) / Cast(count, dt), resultType)
    case _ =>
      Cast(sum, resultType) / Cast(count, resultType)
  }
} 
Example 20
Source File: TypeUtils.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._


object TypeUtils {
  def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (dt.isInstanceOf[NumericType] || dt == NullType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt")
    }
  }

  def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (RowOrdering.isOrderable(dt)) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt")
    }
  }

  def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
    if (types.size <= 1) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      val firstType = types.head
      types.foreach { t =>
        if (!t.sameType(firstType)) {
          return TypeCheckResult.TypeCheckFailure(
            s"input to $caller should all be the same type, but it's " +
              types.map(_.simpleString).mkString("[", ", ", "]"))
        }
      }
      TypeCheckResult.TypeCheckSuccess
    }
  }

  def getNumeric(t: DataType): Numeric[Any] =
    t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]

  def getInterpretedOrdering(t: DataType): Ordering[Any] = {
    t match {
      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
      case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
    }
  }

  def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
    for (i <- 0 until x.length; if i < y.length) {
      val res = x(i).compareTo(y(i))
      if (res != 0) return res
    }
    x.length - y.length
  }
} 
Example 21
Source File: ReferenceToExpressions.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.types.DataType


case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference to any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ =>
    }
    if (maxOrdinal > children.length) {
      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
        s"there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childrenGen = children.map(_.genCode(ctx))
    val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
      case (childGen, child) =>
        // SPARK-18125: The children vars are local variables. If the result expression uses
        // splitExpression, those variables cannot be accessed so compilation fails.
        // To fix it, we use class variables to hold those local variables.
        val classChildVarName = ctx.freshName("classChildVar")
        val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
        ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
        ctx.addMutableState("boolean", classChildVarIsNull, "")

        val classChildVar =
          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)

        val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
          s"${classChildVar.isNull} = ${childGen.isNull};"

        (classChildVar, initCode)
    }.unzip

    val resultGen = result.transform {
      case b: BoundReference => classChildrenVars(b.ordinal)
    }.genCode(ctx)

    ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
      resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
  }
} 
Example 22
Source File: Average.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
case class Average(child: Expression) extends DeclarativeAggregate {

  override def prettyName: String = "avg"

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = resultType

  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForNumericExpr(child.dataType, "function average")

  private lazy val resultType = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      DecimalType.bounded(p + 4, s + 4)
    case _ => DoubleType
  }

  private lazy val sumDataType = child.dataType match {
    case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
    case _ => DoubleType
  }

  private lazy val sum = AttributeReference("sum", sumDataType)()
  private lazy val count = AttributeReference("count", LongType)()

  override lazy val aggBufferAttributes = sum :: count :: Nil

  override lazy val initialValues = Seq(
    /* sum = */ Cast(Literal(0), sumDataType),
    /* count = */ Literal(0L)
  )

  override lazy val updateExpressions = Seq(
    /* sum = */ Add(
      sum,
      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
    /* count = */ If(IsNull(child), count, count + 1L)
  )

  override lazy val mergeExpressions = Seq(
    /* sum = */ sum.left + sum.right,
    /* count = */ count.left + count.right
  )

  // If all input are nulls, count will be 0 and we will get null after the division.
  override lazy val evaluateExpression = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      // increase the precision and scale to prevent precision loss
      val dt = DecimalType.bounded(p + 14, s + 4)
      Cast(Cast(sum, dt) / Cast(count, dt), resultType)
    case _ =>
      Cast(sum, resultType) / Cast(count, resultType)
  }
} 
Example 23
Source File: TimeWindow.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

case class TimeWindow(
    timeColumn: Expression,
    windowDuration: Long,
    slideDuration: Long,
    startTime: Long) extends UnaryExpression
  with ImplicitCastInputTypes
  with Unevaluable
  with NonSQLExpression {

  //////////////////////////
  // SQL Constructors
  //////////////////////////

  def this(
      timeColumn: Expression,
      windowDuration: Expression,
      slideDuration: Expression,
      startTime: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime))
  }

  def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), 0)
  }

  def this(timeColumn: Expression, windowDuration: Expression) = {
    this(timeColumn, windowDuration, windowDuration)
  }

  override def child: Expression = timeColumn
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
  override def dataType: DataType = new StructType()
    .add(StructField("start", TimestampType))
    .add(StructField("end", TimestampType))

  // This expression is replaced in the analyzer.
  override lazy val resolved = false

  override def checkInputDataTypes(): TypeCheckResult = {
    val dataTypeCheck = super.checkInputDataTypes()
    if (dataTypeCheck.isSuccess) {
      if (windowDuration <= 0) {
        return TypeCheckFailure(s"The window duration ($windowDuration) must be greater than 0.")
      }
      if (slideDuration <= 0) {
        return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.")
      }
      if (slideDuration > windowDuration) {
        return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" +
          s" to the windowDuration ($windowDuration).")
      }
      if (startTime.abs >= slideDuration) {
        return TypeCheckFailure(s"The absolute value of start time ($startTime) must be less " +
          s"than the slideDuration ($slideDuration).")
      }
    }
    dataTypeCheck
  }
}

case class PreciseTimestampConversion(
    child: Expression,
    fromType: DataType,
    toType: DataType) extends UnaryExpression with ExpectsInputTypes {
  override def inputTypes: Seq[AbstractDataType] = Seq(fromType)
  override def dataType: DataType = toType
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    ev.copy(code = eval.code +
      code"""boolean ${ev.isNull} = ${eval.isNull};
         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
       """.stripMargin)
  }
  override def nullSafeEval(input: Any): Any = input
} 
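The SQL constructors above call TimeWindow.parseExpression, whose body was elided from this listing. Reconstructed roughly from the upstream Spark source (a sketch, not verbatim: the real helper routes interval strings through a dedicated parser; CalendarInterval.fromString is used here as a stand-in, and NonNullLiteral/IntegerLiteral are extractors from this same package):

object TimeWindow {
  // Fold a literal duration (interval string, int, or long) down to microseconds.
  private def parseExpression(expr: Expression): Long = expr match {
    case NonNullLiteral(s, StringType) =>
      CalendarInterval.fromString("interval " + s.toString).microseconds
    case IntegerLiteral(i) => i.toLong
    case NonNullLiteral(l, LongType) => l.toString.toLong
    case _ => throw new AnalysisException(
      "The duration and time inputs to window must be an integer, long or string literal.")
  }
}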
Example 24
Source File: collect.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.aggregate

import scala.collection.generic.Growable
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._


@ExpressionDescription(
  usage = "_FUNC_(expr) - Collects and returns a set of unique elements.")
case class CollectSet(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {

  def this(child: Expression) = this(child, 0, 0)

  override def checkInputDataTypes(): TypeCheckResult = {
    if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure("collect_set() cannot have map type data")
    }
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "collect_set"

  override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
} 
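For reference, the public DataFrame API reaches this expression through functions.collect_set (the DataFrame and column names below are invented):

import org.apache.spark.sql.functions.collect_set

df.groupBy("user_id").agg(collect_set("page").as("distinct_pages"))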
Example 25
Source File: Average.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

abstract class AverageLike(child: Expression) extends DeclarativeAggregate {

  override def nullable: Boolean = true
  // Return data type.
  override def dataType: DataType = resultType

  private lazy val resultType = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      DecimalType.bounded(p + 4, s + 4)
    case _ => DoubleType
  }

  private lazy val sumDataType = child.dataType match {
    case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
    case _ => DoubleType
  }

  private lazy val sum = AttributeReference("sum", sumDataType)()
  private lazy val count = AttributeReference("count", LongType)()

  override lazy val aggBufferAttributes = sum :: count :: Nil

  override lazy val initialValues = Seq(
    /* sum = */ Literal.default(sumDataType),
    /* count = */ Literal(0L)
  )

  override lazy val mergeExpressions = Seq(
    /* sum = */ sum.left + sum.right,
    /* count = */ count.left + count.right
  )

  // If all input are nulls, count will be 0 and we will get null after the division.
  override lazy val evaluateExpression = child.dataType match {
    case _: DecimalType =>
      DecimalPrecision.decimalAndDecimal(
        sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
    case _ =>
      sum.cast(resultType) / count.cast(resultType)
  }

  protected def updateExpressionsDef: Seq[Expression] = Seq(
    /* sum = */ Add(
      sum,
      Coalesce(Seq(child.cast(sumDataType), Literal.default(sumDataType)))),
    /* count = */ If(child.isNull, count, count + 1L)
  )

  override lazy val updateExpressions = updateExpressionsDef
}

@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
case class Average(child: Expression)
  extends AverageLike(child) with ImplicitCastInputTypes {

  override def prettyName: String = "avg"

  override def children: Seq[Expression] = child :: Nil

  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForNumericExpr(child.dataType, "function average")
} 
Example 26
Source File: SpecialSum.scala    From tispark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.aggregate

import com.pingcap.tispark.utils.ReflectionUtil._
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{
  Add,
  AttributeReference,
  Cast,
  Coalesce,
  Expression,
  ExpressionDescription,
  Literal
}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

object PromotedSum {
  def apply(child: Expression): SpecialSum = {
    val retType = child.dataType match {
      case DecimalType.Fixed(precision, scale) =>
        DecimalType.bounded(precision + 10, scale)
      case _ => DoubleType
    }

    SpecialSum(child, retType, null)
  }

  def unapply(s: SpecialSum): Option[Expression] =
    if (s.initVal == null) Some(s.child) else None
}

object SumNotNullable {
  def apply(child: Expression): SpecialSum = {
    val retType = child.dataType match {
      case DecimalType.Fixed(precision, scale) =>
        DecimalType.bounded(precision + 10, scale)
      case _: IntegralType => LongType
      case _ => DoubleType
    }

    SpecialSum(child, retType, 0)
  }

  def unapply(s: SpecialSum): Option[Expression] =
    if (s.initVal != null) Some(s.child) else None
}

@ExpressionDescription(
  usage =
    "_FUNC_(expr) - Returns the sum calculated from values of a group. Result type is promoted to double/decimal.")
case class SpecialSum(child: Expression, retType: DataType, initVal: Any)
    extends DeclarativeAggregate {

  override lazy val aggBufferAttributes: Seq[AttributeReference] = sum :: Nil
  override lazy val initialValues: Seq[Expression] = Seq(Cast(Literal(initVal), sumDataType))
  // A null input yields a null Add, so Coalesce falls back to the current sum.
  override lazy val updateExpressions: Seq[Expression] =
    Seq(Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)))
  // Merge partial sums, treating a still-null side as zero.
  override lazy val mergeExpressions: Seq[Expression] =
    Seq(Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)))
  override lazy val evaluateExpression: Expression = sum
  private lazy val resultType = retType
  private lazy val sumDataType = resultType
  private lazy val sum = newAttributeReference("rewriteSum", sumDataType)
  private lazy val zero = Cast(Literal(0), sumDataType)

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = resultType

  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForNumericExpr(child.dataType, "function sum")
} 
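
The interesting part of PromotedSum is the result-type promotion: a decimal(p, s) input is summed as decimal(p + 10, s), with DecimalType.bounded capping the precision at 38, while every non-decimal input is summed as double so the result cannot overflow its input type. Here is a minimal standalone sketch of that rule, with toy case classes in place of Catalyst's DataType:

object SumPromotionSketch {
  sealed trait SqlType
  final case class DecimalT(precision: Int, scale: Int) extends SqlType
  case object DoubleT extends SqlType

  val MaxPrecision = 38 // Spark caps decimal precision at 38 digits

  // decimal gains 10 digits of headroom; everything else widens to double
  def promotedSumType(input: SqlType): SqlType = input match {
    case DecimalT(p, s) => DecimalT(math.min(p + 10, MaxPrecision), s)
    case _              => DoubleT
  }

  def main(args: Array[String]): Unit = {
    println(promotedSumType(DecimalT(10, 2))) // DecimalT(20,2)
    println(promotedSumType(DecimalT(35, 4))) // DecimalT(38,4), capped
    println(promotedSumType(DoubleT))         // DoubleT
  }
}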
Example 27
Source File: TypeUtils.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._


object TypeUtils {
  def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (dt.isInstanceOf[NumericType] || dt == NullType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt")
    }
  }

  def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
    if (RowOrdering.isOrderable(dt)) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt")
    }
  }

  def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
    if (types.size <= 1) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      val firstType = types.head
      types.foreach { t =>
        if (!t.sameType(firstType)) {
          return TypeCheckResult.TypeCheckFailure(
            s"input to $caller should all be the same type, but it's " +
              types.map(_.simpleString).mkString("[", ", ", "]"))
        }
      }
      TypeCheckResult.TypeCheckSuccess
    }
  }

  def getNumeric(t: DataType): Numeric[Any] =
    t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]

  def getInterpretedOrdering(t: DataType): Ordering[Any] = {
    t match {
      case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
      case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
    }
  }

  def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
    for (i <- 0 until x.length; if i < y.length) {
      val res = x(i).compareTo(y(i))
      if (res != 0) return res
    }
    x.length - y.length
  }
} 
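
These helpers are what expressions call from checkInputDataTypes, so a mismatch surfaces as an analysis-time TypeCheckFailure rather than a runtime error. A hedged, dependency-free sketch of the checkForSameTypeInputExpr logic, with strings standing in for DataTypes and plain equality standing in for sameType:

object SameTypeCheckSketch {
  sealed trait CheckResult
  case object Success extends CheckResult
  final case class Failure(message: String) extends CheckResult

  // zero or one input is trivially consistent; otherwise every type must
  // match the first one, as in TypeUtils.checkForSameTypeInputExpr
  def checkForSameTypeInput(types: Seq[String], caller: String): CheckResult =
    if (types.size <= 1 || types.forall(_ == types.head)) Success
    else Failure(s"input to $caller should all be the same type, but it's " +
      types.mkString("[", ", ", "]"))

  def main(args: Array[String]): Unit = {
    println(checkForSameTypeInput(Seq("int", "int"), "function concat"))    // Success
    println(checkForSameTypeInput(Seq("int", "string"), "function concat")) // Failure(...)
  }
}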
Example 28
Source File: ReferenceToExpressions.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.types.DataType


case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference to any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ =>
    }
    // Ordinals are 0-based, so the largest valid ordinal is children.length - 1.
    if (maxOrdinal >= children.length) {
      return TypeCheckFailure(s"The result expression needs ${maxOrdinal + 1} input " +
        s"expressions, but there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childrenGen = children.map(_.genCode(ctx))
    val childrenVars = childrenGen.zip(children).map {
      case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
    }

    val resultGen = result.transform {
      case b: BoundReference => childrenVars(b.ordinal)
    }.genCode(ctx)

    ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code,
      isNull = resultGen.isNull, value = resultGen.value)
  }
} 
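
The contract here is positional: every BoundReference(i) in the result expression stands for children(i), and both eval and doGenCode substitute the computed child values by ordinal. A toy illustration of that substitution, using plain Scala functions instead of Catalyst expressions (all names below are hypothetical):

object ReferenceSketch {
  // children are evaluated first; the result then reads them by 0-based ordinal,
  // like BoundReference.ordinal indexing into `children`
  def evalReference(result: IndexedSeq[Any] => Any, children: Seq[() => Any]): Any =
    result(children.map(_()).toIndexedSeq)

  def main(args: Array[String]): Unit = {
    val children = Seq(() => 2, () => 40)
    // result = input(0) + input(1), analogous to BoundReference(0) + BoundReference(1)
    val out = evalReference(row => row(0).asInstanceOf[Int] + row(1).asInstanceOf[Int], children)
    println(out) // 42
  }
}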
Example 29
Source File: TimeWindow.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

case class TimeWindow(
    timeColumn: Expression,
    windowDuration: Long,
    slideDuration: Long,
    startTime: Long) extends UnaryExpression
  with ImplicitCastInputTypes
  with Unevaluable
  with NonSQLExpression {

  //////////////////////////
  // SQL Constructors
  //////////////////////////

  def this(
      timeColumn: Expression,
      windowDuration: Expression,
      slideDuration: Expression,
      startTime: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime))
  }

  def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), 0)
  }

  def this(timeColumn: Expression, windowDuration: Expression) = {
    this(timeColumn, windowDuration, windowDuration)
  }

  override def child: Expression = timeColumn
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
  override def dataType: DataType = new StructType()
    .add(StructField("start", TimestampType))
    .add(StructField("end", TimestampType))

  // This expression is replaced in the analyzer.
  override lazy val resolved = false

  override def checkInputDataTypes(): TypeCheckResult = {
    val dataTypeCheck = super.checkInputDataTypes()
    if (dataTypeCheck.isSuccess) {
      if (windowDuration <= 0) {
        return TypeCheckFailure(s"The window duration ($windowDuration) must be greater than 0.")
      }
      if (slideDuration <= 0) {
        return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.")
      }
      if (slideDuration > windowDuration) {
        return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or " +
          s"equal to the windowDuration ($windowDuration).")
      }
      if (startTime.abs >= slideDuration) {
        return TypeCheckFailure(s"The absolute value of start time ($startTime) must be less " +
          s"than the slideDuration ($slideDuration).")
      }
    }
    dataTypeCheck
  }
}

object TimeWindow {
  // Parses an interval string such as "5 seconds" into microseconds,
  // rejecting blank strings and month-based intervals.
  private def getIntervalInMicroSeconds(interval: String): Long = {
    if (StringUtils.isBlank(interval)) {
      throw new IllegalArgumentException(
        "The window duration, slide duration and start time cannot be null or blank.")
    }
    val intervalString =
      if (interval.startsWith("interval")) interval else "interval " + interval
    val cal = CalendarInterval.fromString(intervalString)
    if (cal == null) {
      throw new IllegalArgumentException(
        s"The provided interval ($interval) did not correspond to a valid interval string.")
    }
    if (cal.months > 0) {
      throw new IllegalArgumentException(
        s"Intervals greater than a month is not supported ($interval).")
    }
    cal.microseconds
  }

  // Accepts an integer/long literal (microseconds) or an interval string literal.
  def parseExpression(expr: Expression): Long = expr match {
    case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString)
    case IntegerLiteral(i) => i.toLong
    case NonNullLiteral(l, LongType) => l.toString.toLong
    case _ => throw new AnalysisException(
      "The duration and time inputs to window must be an integer, long or string literal.")
  }
}
case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes {
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
  override def dataType: DataType = LongType
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    ev.copy(code = eval.code +
      s"""boolean ${ev.isNull} = ${eval.isNull};
         |${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
       """.stripMargin)
  }
} 
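
TimeWindow itself is unevaluable: the analyzer replaces it with arithmetic that assigns each timestamp to its window buckets. For a tumbling window (slideDuration == windowDuration) that arithmetic reduces to a floor operation; the sketch below shows the simplified formula with all durations in microseconds (this is an illustration, not the exact analyzer rewrite):

object WindowBucketSketch {
  // start of the tumbling window containing tsMicros:
  // ts - ((ts - startTime) mod window), with the remainder kept non-negative
  def windowStart(tsMicros: Long, windowMicros: Long, startMicros: Long): Long = {
    val rem = (tsMicros - startMicros) % windowMicros
    tsMicros - (if (rem < 0) rem + windowMicros else rem)
  }

  def main(args: Array[String]): Unit = {
    val fiveSeconds = 5L * 1000 * 1000
    println(windowStart(12345678L, fiveSeconds, 0L)) // 10000000, i.e. the [10s, 15s) bucket
  }
}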
Example 30
Source File: collect.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.aggregate

import scala.collection.generic.Growable
import scala.collection.mutable

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._


@ExpressionDescription(
  usage = "_FUNC_(expr) - Collects and returns a set of unique elements.")
case class CollectSet(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect {

  def this(child: Expression) = this(child, 0, 0)

  override def checkInputDataTypes(): TypeCheckResult = {
    if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure("collect_set() cannot have map type data")
    }
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "collect_set"

  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
}
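
checkInputDataTypes here rejects any input whose type contains a MapType anywhere in its structure, since map values do not have reliable equality for set membership. A standalone sketch of that recursive test over a toy type tree (DataType.existsRecursively is the real Catalyst API; the case classes below are stand-ins):

object ExistsRecursivelySketch {
  sealed trait T
  case object IntT extends T
  case object MapT extends T
  final case class ArrayT(element: T) extends T
  final case class StructT(fields: Seq[T]) extends T

  // does the predicate hold for this type or any type nested inside it?
  def existsRecursively(t: T)(p: T => Boolean): Boolean = p(t) || (t match {
    case ArrayT(e)   => existsRecursively(e)(p)
    case StructT(fs) => fs.exists(existsRecursively(_)(p))
    case _           => false
  })

  def main(args: Array[String]): Unit = {
    println(existsRecursively(ArrayT(StructT(Seq(IntT, MapT))))(_ == MapT)) // true  -> TypeCheckFailure
    println(existsRecursively(StructT(Seq(IntT)))(_ == MapT))               // false -> TypeCheckSuccess
  }
}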