org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext. Each example is taken from an open-source project; the source file, project, and license are noted above each listing.
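For orientation, here is a minimal standalone sketch (not taken from any of the projects below) of a leaf expression whose doGenCode uses the CodegenContext helpers. It assumes the pre-2.4 codegen API, where ExprCode fields are plain strings, which is what most of the listings below target.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.LeafExpression
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

// Hypothetical expression: always evaluates to 42, both interpreted and code-generated.
case class FortyTwo() extends LeafExpression {
  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType
  override def eval(input: InternalRow): Any = 42

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // ctx.freshName avoids clashes with other variables in the generated class.
    val tmp = ctx.freshName("fortyTwo")
    ev.copy(
      code =
        s"""
           |final ${ctx.javaType(dataType)} $tmp = 42;
           |final ${ctx.javaType(dataType)} ${ev.value} = $tmp;
         """.stripMargin,
      isNull = "false")
  }
}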
Example 1
Source File: AvroDataToCatalyst.scala    From spark-schema-registry   with Apache License 2.0
package com.hortonworks.spark.registry.avro

import java.io.ByteArrayInputStream

import com.hortonworks.registries.schemaregistry.{SchemaVersionInfo, SchemaVersionKey}
import com.hortonworks.registries.schemaregistry.client.SchemaRegistryClient
import com.hortonworks.registries.schemaregistry.serdes.avro.AvroSnapshotDeserializer
import org.apache.avro.Schema
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{BinaryType, DataType}

import scala.collection.JavaConverters._


case class AvroDataToCatalyst(child: Expression, schemaName: String, version: Option[Int], config: Map[String, Object])
  extends UnaryExpression with ExpectsInputTypes {

  override def inputTypes = Seq(BinaryType)

  @transient private lazy val srDeser: AvroSnapshotDeserializer = {
    val obj = new AvroSnapshotDeserializer()
    obj.init(config.asJava)
    obj
  }

  @transient private lazy val srSchema = fetchSchemaVersionInfo(schemaName, version)

  @transient private lazy val avroSchema = new Schema.Parser().parse(srSchema.getSchemaText)

  override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType

  @transient private lazy val avroDeser = new AvroDeserializer(avroSchema, dataType)

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any = {
    val binary = input.asInstanceOf[Array[Byte]]
    val row = avroDeser.deserialize(srDeser.deserialize(new ByteArrayInputStream(binary), srSchema.getVersion))
    val result = row match {
      case r: InternalRow => r.copy()
      case _ => row
    }
    result
  }

  override def simpleString: String = {
    s"from_sr(${child.sql}, ${dataType.simpleString})"
  }

  override def sql: String = {
    s"from_sr(${child.sql}, ${dataType.catalogString})"
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val expr = ctx.addReferenceObj("this", this)
    defineCodeGen(ctx, ev, input =>
      s"(${ctx.boxedType(dataType)})$expr.nullSafeEval($input)")
  }

  private def fetchSchemaVersionInfo(schemaName: String, version: Option[Int]): SchemaVersionInfo = {
    val srClient = new SchemaRegistryClient(config.asJava)
    version.map(v => srClient.getSchemaVersionInfo(new SchemaVersionKey(schemaName, v)))
      .getOrElse(srClient.getLatestSchemaVersionInfo(schemaName))
  }

} 
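In practice the expression above is exposed to users as a DataFrame function. A minimal sketch of such a wrapper follows; the object and method names here are hypothetical, not the library's actual API, and it assumes the wrapper lives where AvroDataToCatalyst is visible.

import org.apache.spark.sql.Column

object SchemaRegistrySupport {
  // Hypothetical helper: wraps AvroDataToCatalyst so it can be applied to a DataFrame column.
  def from_sr(data: Column, schemaName: String, config: Map[String, Object]): Column =
    new Column(AvroDataToCatalyst(data.expr, schemaName, version = None, config))
}

// Usage, assuming `df` has a binary column "value" holding registry-serialized Avro records:
// df.select(SchemaRegistrySupport.from_sr(df("value"), "my-schema", config).as("record"))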
Example 2
Source File: ReferenceToExpressions.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.types.DataType


case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference to any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ =>
    }
    if (maxOrdinal > children.length) {
      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
        s"there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childrenGen = children.map(_.genCode(ctx))
    val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
      case (childGen, child) =>
        // SPARK-18125: The children vars are local variables. If the result expression uses
        // splitExpression, those variables cannot be accessed so compilation fails.
        // To fix it, we use class variables to hold those local variables.
        val classChildVarName = ctx.freshName("classChildVar")
        val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
        ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
        ctx.addMutableState("boolean", classChildVarIsNull, "")

        val classChildVar =
          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)

        val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
          s"${classChildVar.isNull} = ${childGen.isNull};"

        (classChildVar, initCode)
    }.unzip

    val resultGen = result.transform {
      case b: BoundReference => classChildrenVars(b.ordinal)
    }.genCode(ctx)

    ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
      resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
  }
} 
Example 3
Source File: BoundAttribute.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._


case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {
    if (input.isNullAt(ordinal)) {
      null
    } else {
      dataType match {
        case BooleanType => input.getBoolean(ordinal)
        case ByteType => input.getByte(ordinal)
        case ShortType => input.getShort(ordinal)
        case IntegerType | DateType => input.getInt(ordinal)
        case LongType | TimestampType => input.getLong(ordinal)
        case FloatType => input.getFloat(ordinal)
        case DoubleType => input.getDouble(ordinal)
        case StringType => input.getUTF8String(ordinal)
        case BinaryType => input.getBinary(ordinal)
        case CalendarIntervalType => input.getInterval(ordinal)
        case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
        case t: StructType => input.getStruct(ordinal, t.size)
        case _: ArrayType => input.getArray(ordinal)
        case _: MapType => input.getMap(ordinal)
        case _ => input.get(ordinal, dataType)
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = ctx.javaType(dataType)
    val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      val code = oev.code
      oev.code = ""
      ev.copy(code = code)
    } else if (nullable) {
      ev.copy(code = s"""
        boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
        $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""")
    } else {
      ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false")
    }
  }
}

object BindReferences extends Logging {

  def bindReference[A <: Expression](
      expression: A,
      input: AttributeSeq,
      allowFailures: Boolean = false): A = {
    expression.transform { case a: AttributeReference =>
      attachTree(a, "Binding attribute") {
        val ordinal = input.indexOf(a.exprId)
        if (ordinal == -1) {
          if (allowFailures) {
            a
          } else {
            sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
          }
        } else {
          BoundReference(ordinal, a.dataType, input(ordinal).nullable)
        }
      }
    }.asInstanceOf[A] // Kind of a hack, but safe.  TODO: Tighten return type when possible.
  }
} 
Example 4
Source File: decimalExpressions.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._


case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any = {
    val d = input.asInstanceOf[Decimal].clone()
    if (d.changePrecision(dataType.precision, dataType.scale)) {
      d
    } else {
      null
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val tmp = ctx.freshName("tmp")
      s"""
         | Decimal $tmp = $eval.clone();
         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
         |   ${ev.value} = $tmp;
         | } else {
         |   ${ev.isNull} = true;
         | }
       """.stripMargin
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql
} 
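CheckOverflow delegates the actual range check to Decimal.changePrecision, which rounds the value in place and returns false when it cannot fit the target precision and scale. A small illustrative sketch (values chosen arbitrarily):

import org.apache.spark.sql.types.Decimal

val fits = Decimal("12.345")
println(fits.changePrecision(5, 2))      // true: the value fits after rounding to scale 2

val overflows = Decimal("12345.6")
println(overflows.changePrecision(4, 2)) // false: precision 4 cannot hold 12345.6,
                                         // so CheckOverflow would evaluate to null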
Example 5
Source File: ReferenceToExpressions.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.types.DataType


case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference to any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ =>
    }
    if (maxOrdinal > children.length) {
      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
        s"there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childrenGen = children.map(_.genCode(ctx))
    val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
      case (childGen, child) =>
        // SPARK-18125: The children vars are local variables. If the result expression uses
        // splitExpression, those variables cannot be accessed so compilation fails.
        // To fix it, we use class variables to hold those local variables.
        val classChildVarName = ctx.freshName("classChildVar")
        val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
        ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
        ctx.addMutableState("boolean", classChildVarIsNull, "")

        val classChildVar =
          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)

        val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
          s"${classChildVar.isNull} = ${childGen.isNull};"

        (classChildVar, initCode)
    }.unzip

    val resultGen = result.transform {
      case b: BoundReference => classChildrenVars(b.ordinal)
    }.genCode(ctx)

    ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
      resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
  }
} 
Example 6
Source File: ExpressionEvalHelperSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}


case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
} 
Example 7
Source File: CheckDeltaInvariant.scala    From delta   with Apache License 2.0
package org.apache.spark.sql.delta.schema

import org.apache.spark.sql.delta.schema.Invariants.{ArbitraryExpression, NotNull}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{Block, CodegenContext, ExprCode, JavaCode, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, NullType}


case class CheckDeltaInvariant(
    child: Expression,
    invariant: Invariant) extends UnaryExpression with NonSQLExpression {

  override def dataType: DataType = NullType
  override def foldable: Boolean = false
  override def nullable: Boolean = true

  override def flatArguments: Iterator[Any] = Iterator(child)

  private def assertRule(input: InternalRow): Unit = invariant.rule match {
    case NotNull if child.eval(input) == null =>
      throw InvariantViolationException(invariant, "")
    case ArbitraryExpression(expr) =>
      val resolvedExpr = expr.transform {
        case _: UnresolvedAttribute => child
      }
      val result = resolvedExpr.eval(input)
      if (result == null || result == false) {
        throw InvariantViolationException(
          invariant, s"Value ${child.eval(input)} violates requirement.")
      }
  }

  override def eval(input: InternalRow): Any = {
    assertRule(input)
    null
  }

  private def generateNotNullCode(ctx: CodegenContext): Block = {
    val childGen = child.genCode(ctx)
    val invariantField = ctx.addReferenceObj("errMsg", invariant)
    code"""${childGen.code}
       |
       |if (${childGen.isNull}) {
       |  throw org.apache.spark.sql.delta.schema.InvariantViolationException.apply(
       |    $invariantField, "");
       |}
     """.stripMargin
  }

  private def generateExpressionValidationCode(expr: Expression, ctx: CodegenContext): Block = {
    val resolvedExpr = expr.transform {
      case _: UnresolvedAttribute => child
    }
    val elementValue = child.genCode(ctx)
    val childGen = resolvedExpr.genCode(ctx)
    val invariantField = ctx.addReferenceObj("errMsg", invariant)
    val eValue = ctx.freshName("elementResult")
    code"""${elementValue.code}
       |${childGen.code}
       |
       |if (${childGen.isNull} || ${childGen.value} == false) {
       |  Object $eValue = "null";
       |  if (!${elementValue.isNull}) {
       |    $eValue = (Object) ${elementValue.value};
       |  }
       |  throw org.apache.spark.sql.delta.schema.InvariantViolationException.apply(
       |     $invariantField, "Value " + $eValue + " violates requirement.");
       |}
     """.stripMargin
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val code = invariant.rule match {
      case NotNull => generateNotNullCode(ctx)
      case ArbitraryExpression(expr) => generateExpressionValidationCode(expr, ctx)
    }
    ev.copy(code = code, isNull = TrueLiteral, value = JavaCode.literal("null", NullType))
  }
} 
Example 8
Source File: Serialize.scala    From morpheus   with Apache License 2.0
package org.opencypher.morpheus.impl.expressions

import java.io.ByteArrayOutputStream

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, _}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.opencypher.morpheus.impl.expressions.EncodeLong.encodeLong
import org.opencypher.morpheus.impl.expressions.Serialize._
import org.opencypher.okapi.impl.exception


case class Serialize(children: Seq[Expression]) extends Expression {

  override def dataType: DataType = BinaryType

  override def nullable: Boolean = false

  // TODO: Only write length if more than one column is serialized
  override def eval(input: InternalRow): Any = {
    // TODO: Reuse from a pool instead of allocating a new one for each serialization
    val out = new ByteArrayOutputStream()
    children.foreach { child =>
      child.dataType match {
        case BinaryType => write(child.eval(input).asInstanceOf[Array[Byte]], out)
        case StringType => write(child.eval(input).asInstanceOf[UTF8String], out)
        case IntegerType => write(child.eval(input).asInstanceOf[Int], out)
        case LongType => write(child.eval(input).asInstanceOf[Long], out)
        case other => throw exception.UnsupportedOperationException(s"Cannot serialize Spark data type $other.")
      }
    }
    out.toByteArray
  }

  override protected def doGenCode(
    ctx: CodegenContext,
    ev: ExprCode
  ): ExprCode = {
    ev.isNull = FalseLiteral
    val out = ctx.freshName("out")
    val serializeChildren = children.map { child =>
      val childEval = child.genCode(ctx)
      s"""|${childEval.code}
          |if (!${childEval.isNull}) {
          |  ${Serialize.getClass.getName.dropRight(1)}.write(${childEval.value}, $out);
          |}""".stripMargin
    }.mkString("\n")
    val baos = classOf[ByteArrayOutputStream].getName
    ev.copy(
      code = code"""|$baos $out = new $baos();
          |$serializeChildren
          |byte[] ${ev.value} = $out.toByteArray();""".stripMargin)
  }

}

object Serialize {

  val supportedTypes: Set[DataType] = Set(BinaryType, StringType, IntegerType, LongType)

  @inline final def write(value: Array[Byte], out: ByteArrayOutputStream): Unit = {
    out.write(encodeLong(value.length))
    out.write(value)
  }

  @inline final def write(
    value: Boolean,
    out: ByteArrayOutputStream
  ): Unit = write(if (value) 1.toLong else 0.toLong, out)

  @inline final def write(value: Byte, out: ByteArrayOutputStream): Unit = write(value.toLong, out)

  @inline final def write(value: Int, out: ByteArrayOutputStream): Unit = write(value.toLong, out)

  @inline final def write(value: Long, out: ByteArrayOutputStream): Unit = write(encodeLong(value), out)

  @inline final def write(value: UTF8String, out: ByteArrayOutputStream): Unit = write(value.getBytes, out)

  @inline final def write(value: String, out: ByteArrayOutputStream): Unit = write(value.getBytes, out)

} 
Example 9
Source File: EncodeLong.scala    From morpheus   with Apache License 2.0
package org.opencypher.morpheus.impl.expressions

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, NullIntolerant, UnaryExpression}
import org.apache.spark.sql.types.{BinaryType, DataType, LongType}
import org.opencypher.morpheus.api.value.MorpheusElement._


case class EncodeLong(child: Expression) extends UnaryExpression with NullIntolerant with ExpectsInputTypes {

  override val dataType: DataType = BinaryType

  override val inputTypes: Seq[LongType] = Seq(LongType)

  override protected def nullSafeEval(input: Any): Any =
    EncodeLong.encodeLong(input.asInstanceOf[Long])

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, c => s"(byte[])(${EncodeLong.getClass.getName.dropRight(1)}.encodeLong($c))")
}

object EncodeLong {

  private final val moreBytesBitMask: Long = Integer.parseInt("10000000", 2)
  private final val varLength7BitMask: Long = Integer.parseInt("01111111", 2)
  private final val otherBitsMask = ~varLength7BitMask
  private final val maxBytesForLongVarEncoding = 10

  // Same encoding as Base 128 Varints @ https://developers.google.com/protocol-buffers/docs/encoding
  @inline
  final def encodeLong(l: Long): Array[Byte] = {
    val tempResult = new Array[Byte](maxBytesForLongVarEncoding)

    var remainder = l
    var index = 0

    while ((remainder & otherBitsMask) != 0) {
      tempResult(index) = ((remainder & varLength7BitMask) | moreBytesBitMask).toByte
      remainder >>>= 7
      index += 1
    }
    tempResult(index) = remainder.toByte

    val result = new Array[Byte](index + 1)
    System.arraycopy(tempResult, 0, result, 0, index + 1)
    result
  }

  // Same encoding as Base 128 Varints @ https://developers.google.com/protocol-buffers/docs/encoding
  @inline
  final def decodeLong(input: Array[Byte]): Long = {
    assert(input.nonEmpty, "`decodeLong` requires a non-empty array as its input")
    var index = 0
    var currentByte = input(index)
    var decoded = currentByte & varLength7BitMask
    var nextLeftShift = 7

    while ((currentByte & moreBytesBitMask) != 0) {
      index += 1
      currentByte = input(index)
      decoded |= (currentByte & varLength7BitMask) << nextLeftShift
      nextLeftShift += 7
    }
    assert(index == input.length - 1,
      s"`decodeLong` received an input array ${input.toSeq.toHex} with extra bytes that could not be decoded.")
    decoded
  }

  implicit class ColumnLongOps(val c: Column) extends AnyVal {

    def encodeLongAsMorpheusId(name: String): Column = encodeLongAsMorpheusId.as(name)

    def encodeLongAsMorpheusId: Column = new Column(EncodeLong(c.expr))

  }

} 
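Since encodeLong and decodeLong are inverses, a quick way to sanity-check the variable-length encoding is a round trip; the sketch below is illustrative and not part of the project above.

import org.opencypher.morpheus.impl.expressions.EncodeLong.{decodeLong, encodeLong}

val samples = Seq(0L, 1L, 127L, 128L, 1L << 20, Long.MaxValue)
samples.foreach { v =>
  val bytes = encodeLong(v)
  // Values up to 127 occupy a single byte; each additional byte carries 7 more payload bits.
  assert(decodeLong(bytes) == v, s"round trip failed for $v")
}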
Example 10
Source File: MonotonicallyIncreasingID.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}


// The @ExpressionDescription scaladoc and the class header were dropped by the example
// extractor; the body below belongs to MonotonicallyIncreasingID, declared in Spark roughly as:
case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    count = 0L
    partitionMask = partitionIndex.toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"
} 
Example 11
Source File: BoundAttribute.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._


case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {
    if (input.isNullAt(ordinal)) {
      null
    } else {
      dataType match {
        case BooleanType => input.getBoolean(ordinal)
        case ByteType => input.getByte(ordinal)
        case ShortType => input.getShort(ordinal)
        case IntegerType | DateType => input.getInt(ordinal)
        case LongType | TimestampType => input.getLong(ordinal)
        case FloatType => input.getFloat(ordinal)
        case DoubleType => input.getDouble(ordinal)
        case StringType => input.getUTF8String(ordinal)
        case BinaryType => input.getBinary(ordinal)
        case CalendarIntervalType => input.getInterval(ordinal)
        case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
        case t: StructType => input.getStruct(ordinal, t.size)
        case _: ArrayType => input.getArray(ordinal)
        case _: MapType => input.getMap(ordinal)
        case _ => input.get(ordinal, dataType)
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = ctx.javaType(dataType)
    val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      val code = oev.code
      oev.code = ""
      ev.copy(code = code)
    } else if (nullable) {
      ev.copy(code = s"""
        boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
        $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""")
    } else {
      ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false")
    }
  }
}

object BindReferences extends Logging {

  def bindReference[A <: Expression](
      expression: A,
      input: AttributeSeq,
      allowFailures: Boolean = false): A = {
    expression.transform { case a: AttributeReference =>
      attachTree(a, "Binding attribute") {
        val ordinal = input.indexOf(a.exprId)
        if (ordinal == -1) {
          if (allowFailures) {
            a
          } else {
            sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
          }
        } else {
          BoundReference(ordinal, a.dataType, input(ordinal).nullable)
        }
      }
    }.asInstanceOf[A] // Kind of a hack, but safe.  TODO: Tighten return type when possible.
  }
} 
Example 12
Source File: decimalExpressions.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._


case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any = {
    val d = input.asInstanceOf[Decimal].clone()
    if (d.changePrecision(dataType.precision, dataType.scale)) {
      d
    } else {
      null
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val tmp = ctx.freshName("tmp")
      s"""
         | Decimal $tmp = $eval.clone();
         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
         |   ${ev.value} = $tmp;
         | } else {
         |   ${ev.isNull} = true;
         | }
       """.stripMargin
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql
} 
Example 13
Source File: MonotonicallyIncreasingID.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}


// The @ExpressionDescription scaladoc and the class header were dropped by the example
// extractor; the body below belongs to MonotonicallyIncreasingID, declared in Spark roughly as:
case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    count = 0L
    partitionMask = partitionIndex.toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"
} 
Example 14
Source File: ExpressionEvalHelperSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}


case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
} 
Example 15
Source File: MonotonicallyIncreasingID.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}


// The @ExpressionDescription scaladoc and the class header were dropped by the example
// extractor; the body below belongs to MonotonicallyIncreasingID, declared in Spark roughly as:
case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    count = 0L
    partitionMask = partitionIndex.toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
    val partitionMaskTerm = "partitionMask"
    ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"
} 
Example 16
Source File: BoundAttribute.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._


case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {
    if (input.isNullAt(ordinal)) {
      null
    } else {
      dataType match {
        case BooleanType => input.getBoolean(ordinal)
        case ByteType => input.getByte(ordinal)
        case ShortType => input.getShort(ordinal)
        case IntegerType | DateType => input.getInt(ordinal)
        case LongType | TimestampType => input.getLong(ordinal)
        case FloatType => input.getFloat(ordinal)
        case DoubleType => input.getDouble(ordinal)
        case StringType => input.getUTF8String(ordinal)
        case BinaryType => input.getBinary(ordinal)
        case CalendarIntervalType => input.getInterval(ordinal)
        case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
        case t: StructType => input.getStruct(ordinal, t.size)
        case _: ArrayType => input.getArray(ordinal)
        case _: MapType => input.getMap(ordinal)
        case _ => input.get(ordinal, dataType)
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      ev.copy(code = oev.code)
    } else {
      assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
      val javaType = ctx.javaType(dataType)
      val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
      if (nullable) {
        ev.copy(code =
          s"""
             |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
             |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
           """.stripMargin)
      } else {
        ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
      }
    }
  }
}

object BindReferences extends Logging {

  def bindReference[A <: Expression](
      expression: A,
      input: AttributeSeq,
      allowFailures: Boolean = false): A = {
    expression.transform { case a: AttributeReference =>
      attachTree(a, "Binding attribute") {
        val ordinal = input.indexOf(a.exprId)
        if (ordinal == -1) {
          if (allowFailures) {
            a
          } else {
            sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
          }
        } else {
          BoundReference(ordinal, a.dataType, input(ordinal).nullable)
        }
      }
    }.asInstanceOf[A] // Kind of a hack, but safe.  TODO: Tighten return type when possible.
  }
} 
Example 17
Source File: decimalExpressions.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._


case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val tmp = ctx.freshName("tmp")
      s"""
         | Decimal $tmp = $eval.clone();
         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
         |   ${ev.value} = $tmp;
         | } else {
         |   ${ev.isNull} = true;
         | }
       """.stripMargin
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql
} 
Example 18
Source File: ScalaUDFSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase(Locale.ROOT),
      StringType,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

  test("SPARK-22695: ScalaUDF should not use global variables") {
    val ctx = new CodegenContext
    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
    assert(ctx.inlinedMutableStates.isEmpty)
  }
} 
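The SPARK-22695 test above illustrates the common pattern of instantiating a CodegenContext directly to inspect what an expression generates. A minimal standalone sketch of the same idea (assuming a Spark 2.3-style classpath):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val ctx = new CodegenContext
val ev = Literal("a").genCode(ctx)

println(ev.code)                          // the generated Java fragment
println(ev.value)                         // the variable that will hold the result
println(ctx.inlinedMutableStates.isEmpty) // true when no global mutable state was registered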
Example 19
Source File: ExpressionEvalHelperSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}


case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
} 
Example 20
Source File: CatalystDataToAvro.scala    From spark-schema-registry   with Apache License 2.0
package com.hortonworks.spark.registry.avro

import com.hortonworks.registries.schemaregistry.{SchemaCompatibility, SchemaMetadata}
import com.hortonworks.registries.schemaregistry.avro.AvroSchemaProvider
import com.hortonworks.registries.schemaregistry.client.SchemaRegistryClient
import com.hortonworks.registries.schemaregistry.serdes.avro.AvroSnapshotSerializer
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{BinaryType, DataType}

import scala.collection.JavaConverters._


case class CatalystDataToAvro(
    child: Expression,
    schemaName: String,
    recordName: String,
    nameSpace: String,
    config: Map[String, Object]
    ) extends UnaryExpression {

  override def dataType: DataType = BinaryType

  private val topLevelRecordName = if (recordName == "") schemaName else recordName

  @transient private lazy val avroType =
    SchemaConverters.toAvroType(child.dataType, child.nullable, topLevelRecordName, nameSpace)

  @transient private lazy val avroSer =
    new AvroSerializer(child.dataType, avroType, child.nullable)

  @transient private lazy val srSer: AvroSnapshotSerializer = {
    val obj = new AvroSnapshotSerializer()
    obj.init(config.asJava)
    obj
  }

  @transient private lazy val srClient = new SchemaRegistryClient(config.asJava)

  @transient private lazy val schemaMetadata = {
    var schemaMetadataInfo = srClient.getSchemaMetadataInfo(schemaName)
    if (schemaMetadataInfo == null) {
      val generatedSchemaMetadata = new SchemaMetadata.Builder(schemaName).
        `type`(AvroSchemaProvider.TYPE)
        .schemaGroup("Autogenerated group")
        .description("Autogenerated schema")
        .compatibility(SchemaCompatibility.BACKWARD).build
      srClient.addSchemaMetadata(generatedSchemaMetadata)
      generatedSchemaMetadata
    } else {
      schemaMetadataInfo.getSchemaMetadata
    }
  }

  override def nullSafeEval(input: Any): Any = {
    val avroData = avroSer.serialize(input)
    srSer.serialize(avroData.asInstanceOf[Object], schemaMetadata)
  }

  override def simpleString: String = {
    s"to_sr(${child.sql}, ${child.dataType.simpleString})"
  }

  override def sql: String = {
    s"to_sr(${child.sql}, ${child.dataType.catalogString})"
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val expr = ctx.addReferenceObj("this", this)
    defineCodeGen(ctx, ev, input =>
      s"(byte[]) $expr.nullSafeEval($input)")
  }
} 
Example 21
Source File: MeanSubstitute.scala    From glow   with Apache License 2.0
package io.projectglow.sql.expressions

import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, NumericType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.sql.dsl._
import io.projectglow.sql.util.RewriteAfterResolution


case class MeanSubstitute(array: Expression, missingValue: Expression)
    extends RewriteAfterResolution {

  override def children: Seq[Expression] = Seq(array, missingValue)

  def this(array: Expression) = {
    this(array, Literal(-1))
  }

  private lazy val arrayElementType = array.dataType.asInstanceOf[ArrayType].elementType

  // A value is considered missing if it is NaN, null or equal to the missing value parameter
  def isMissing(arrayElement: Expression): Predicate =
    IsNaN(arrayElement) || IsNull(arrayElement) || arrayElement === missingValue

  def createNamedStruct(sumValue: Expression, countValue: Expression): Expression = {
    val sumName = Literal(UTF8String.fromString("sum"), StringType)
    val countName = Literal(UTF8String.fromString("count"), StringType)
    namedStruct(sumName, sumValue, countName, countValue)
  }

  // Update sum and count with array element if not missing
  def updateSumAndCountConditionally(
      stateStruct: Expression,
      arrayElement: Expression): Expression = {
    If(
      isMissing(arrayElement),
      // If value is missing, do not update sum and count
      stateStruct,
      // If value is not missing, add to sum and increment count
      createNamedStruct(
        stateStruct.getField("sum") + arrayElement,
        stateStruct.getField("count") + 1)
    )
  }

  // Calculate mean for imputation
  def calculateMean(stateStruct: Expression): Expression = {
    If(
      stateStruct.getField("count") > 0,
      // If non-missing values were found, calculate the average
      stateStruct.getField("sum") / stateStruct.getField("count"),
      // If all values were missing, substitute with missing value
      missingValue
    )
  }

  lazy val arrayMean: Expression = {
    // Sum and count of non-missing values
    array.aggregate(
      createNamedStruct(Literal(0d), Literal(0L)),
      updateSumAndCountConditionally,
      calculateMean
    )
  }

  def substituteWithMean(arrayElement: Expression): Expression = {
    If(isMissing(arrayElement), arrayMean, arrayElement)
  }

  override def rewrite: Expression = {
    if (!array.dataType.isInstanceOf[ArrayType] || !arrayElementType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Can only perform mean substitution on numeric array; provided type is ${array.dataType}.")
    }

    if (!missingValue.dataType.isInstanceOf[NumericType]) {
      throw SQLUtils.newAnalysisException(
        s"Missing value must be of numeric type; provided type is ${missingValue.dataType}.")
    }

    // Replace missing values with the provided strategy
    array.arrayTransform(substituteWithMean(_))
  }
} 
Example 22
Source File: LinearRegressionExpr.scala    From glow   with Apache License 2.0
package io.projectglow.sql.expressions

import breeze.linalg.DenseVector
import org.apache.spark.TaskContext
import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, TernaryExpression}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._

object LinearRegressionExpr {
  private val matrixUDT = SQLUtils.newMatrixUDT()
  private val state = new ThreadLocal[CovariateQRContext]

  def doLinearRegression(genotypes: Any, phenotypes: Any, covariates: Any): InternalRow = {

    if (state.get() == null) {
      // Save the QR factorization of the covariate matrix since it's the same for every row
      state.set(CovariateQRContext.computeQR(matrixUDT.deserialize(covariates).toDense))
      TaskContext.get().addTaskCompletionListener[Unit](_ => state.remove())
    }

    LinearRegressionGwas.linearRegressionGwas(
      new DenseVector[Double](genotypes.asInstanceOf[ArrayData].toDoubleArray()),
      new DenseVector[Double](phenotypes.asInstanceOf[ArrayData].toDoubleArray()),
      state.get()
    )
  }
}

case class LinearRegressionExpr(
    genotypes: Expression,
    phenotypes: Expression,
    covariates: Expression)
    extends TernaryExpression
    with ImplicitCastInputTypes {

  private val matrixUDT = SQLUtils.newMatrixUDT()

  override def dataType: DataType =
    StructType(
      Seq(
        StructField("beta", DoubleType),
        StructField("standardError", DoubleType),
        StructField("pValue", DoubleType)))

  override def inputTypes: Seq[DataType] =
    Seq(ArrayType(DoubleType), ArrayType(DoubleType), matrixUDT)

  override def children: Seq[Expression] = Seq(genotypes, phenotypes, covariates)

  override protected def nullSafeEval(genotypes: Any, phenotypes: Any, covariates: Any): Any = {
    LinearRegressionExpr.doLinearRegression(genotypes, phenotypes, covariates)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(
      ctx,
      ev,
      (genotypes, phenotypes, covariates) => {
        s"""
         |${ev.value} = io.projectglow.sql.expressions.LinearRegressionExpr.doLinearRegression($genotypes, $phenotypes, $covariates);
       """.stripMargin
      }
    )
  }
} 
Example 23
Source File: inputFileBlock.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String


@ExpressionDescription(
  usage = "_FUNC_() - Returns the name of the file being read, or empty string if not available.")
case class InputFileName() extends LeafExpression with Nondeterministic {

  override def nullable: Boolean = false

  override def dataType: DataType = StringType

  override def prettyName: String = "input_file_name"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): UTF8String = {
    InputFileBlockHolder.getInputFilePath
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();",
      isNull = FalseLiteral)
  }
}


@ExpressionDescription(
  usage = "_FUNC_() - Returns the start offset of the block being read, or -1 if not available.")
case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override def prettyName: String = "input_file_block_start"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): Long = {
    InputFileBlockHolder.getStartOffset
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral)
  }
}


@ExpressionDescription(
  usage = "_FUNC_() - Returns the length of the block being read, or -1 if not available.")
case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override def prettyName: String = "input_file_block_length"

  override protected def initializeInternal(partitionIndex: Int): Unit = {}

  override protected def evalInternal(input: InternalRow): Long = {
    InputFileBlockHolder.getLength
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
    val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
    ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
  }
} 
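These three expressions back the input_file_name, input_file_block_start and input_file_block_length SQL functions. A brief usage sketch; the SparkSession `spark` and the Parquet path are assumptions for illustration, not part of the listing:

// Shows, for each row, which file and block it was read from.
spark.read.parquet("/data/events")
  .selectExpr(
    "input_file_name()",
    "input_file_block_start()",
    "input_file_block_length()")
  .show(truncate = false)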
Example 24
Source File: MonotonicallyIncreasingID.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}


// The @ExpressionDescription scaladoc and the class header were dropped by the example
// extractor; the body below belongs to MonotonicallyIncreasingID, declared in Spark roughly as:
case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  override protected def initInternal(): Unit = {
    count = 0L
    partitionMask = TaskContext.getPartitionId().toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.freshName("count")
    val partitionMaskTerm = ctx.freshName("partitionMask")
    ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
      s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")

    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = "false")
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"
} 
Example 25
Source File: randomExpressions.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom


@ExpressionDescription(
  usage = "_FUNC_(a) - Returns a random column with i.i.d. gaussian random distribution.")
case class Randn(seed: Long) extends RDG {
  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  def this() = this(Utils.random.nextLong())

  def this(seed: Expression) = this(seed match {
    case IntegerLiteral(s) => s
    case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
  })

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val rngTerm = ctx.freshName("rng")
    val className = classOf[XORShiftRandom].getName
    ctx.addMutableState(className, rngTerm,
      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
    ev.copy(code = s"""
      final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
  }
} 
Example 26
Source File: TimeWindow.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

case class TimeWindow(
    timeColumn: Expression,
    windowDuration: Long,
    slideDuration: Long,
    startTime: Long) extends UnaryExpression
  with ImplicitCastInputTypes
  with Unevaluable
  with NonSQLExpression {

  //////////////////////////
  // SQL Constructors
  //////////////////////////

  def this(
      timeColumn: Expression,
      windowDuration: Expression,
      slideDuration: Expression,
      startTime: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime))
  }

  def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), 0)
  }

  def this(timeColumn: Expression, windowDuration: Expression) = {
    this(timeColumn, windowDuration, windowDuration)
  }

  override def child: Expression = timeColumn
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
  override def dataType: DataType = new StructType()
    .add(StructField("start", TimestampType))
    .add(StructField("end", TimestampType))

  // This expression is replaced in the analyzer.
  override lazy val resolved = false
}

// (The remainder of TimeWindow, including the TimeWindow companion object whose parseExpression
// helper is referenced by the constructors above, was truncated by the example extractor.)

case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes {
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
  override def dataType: DataType = LongType
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    ev.copy(code = eval.code +
      s"""boolean ${ev.isNull} = ${eval.isNull};
         |${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
       """.stripMargin)
  }
} 
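
The SQL constructors above delegate to a TimeWindow.parseExpression helper that this listing truncates; conceptually it folds a foldable duration expression such as "10 seconds" into a microsecond count. Below is a rough stand-in for that conversion, assuming the Spark 2.x CalendarInterval API (durationToMicros is a hypothetical name):

import org.apache.spark.unsafe.types.CalendarInterval

// Hypothetical stand-in for the truncated duration parsing: string -> microseconds.
def durationToMicros(duration: String): Long = {
  val interval = CalendarInterval.fromString("interval " + duration)
  require(interval != null, s"Invalid interval string: $duration")
  require(interval.months == 0, "Intervals of a month or more are not supported")
  interval.microseconds
}

durationToMicros("10 seconds") // 10000000L
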
Example 27
Source File: BoundAttribute.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._


case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {
    if (input.isNullAt(ordinal)) {
      null
    } else {
      dataType match {
        case BooleanType => input.getBoolean(ordinal)
        case ByteType => input.getByte(ordinal)
        case ShortType => input.getShort(ordinal)
        case IntegerType | DateType => input.getInt(ordinal)
        case LongType | TimestampType => input.getLong(ordinal)
        case FloatType => input.getFloat(ordinal)
        case DoubleType => input.getDouble(ordinal)
        case StringType => input.getUTF8String(ordinal)
        case BinaryType => input.getBinary(ordinal)
        case CalendarIntervalType => input.getInterval(ordinal)
        case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
        case t: StructType => input.getStruct(ordinal, t.size)
        case _: ArrayType => input.getArray(ordinal)
        case _: MapType => input.getMap(ordinal)
        case _ => input.get(ordinal, dataType)
      }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = ctx.javaType(dataType)
    val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      val code = oev.code
      oev.code = ""
      ev.copy(code = code)
    } else if (nullable) {
      ev.copy(code = s"""
        boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
        $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""")
    } else {
      ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false")
    }
  }
}

object BindReferences extends Logging {

  def bindReference[A <: Expression](
      expression: A,
      input: AttributeSeq,
      allowFailures: Boolean = false): A = {
    expression.transform { case a: AttributeReference =>
      attachTree(a, "Binding attribute") {
        val ordinal = input.indexOf(a.exprId)
        if (ordinal == -1) {
          if (allowFailures) {
            a
          } else {
            sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
          }
        } else {
          BoundReference(ordinal, a.dataType, input(ordinal).nullable)
        }
      }
    }.asInstanceOf[A] // Kind of a hack, but safe.  TODO: Tighten return type when possible.
  }
} 
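
As a small REPL-style usage sketch (ours, not from the project), binding an AttributeReference against an input schema rewrites it into the positional BoundReference defined above:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

// b sits at ordinal 1 of the input, so it becomes input[1, int, true].
val bound = BindReferences.bindReference(b: Expression, Seq(a, b))
// bound == BoundReference(1, IntegerType, nullable = true)
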
Example 28
Source File: decimalExpressions.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._


case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any = {
    val d = input.asInstanceOf[Decimal].clone()
    if (d.changePrecision(dataType.precision, dataType.scale)) {
      d
    } else {
      null
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val tmp = ctx.freshName("tmp")
      s"""
         | Decimal $tmp = $eval.clone();
         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
         |   ${ev.value} = $tmp;
         | } else {
         |   ${ev.isNull} = true;
         | }
       """.stripMargin
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql
} 
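
The interpreted path is easy to exercise directly; a short REPL-style sketch (ours) against the same Decimal API:

import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, Literal}
import org.apache.spark.sql.types.{Decimal, DecimalType}

// Values that still fit after changePrecision are rounded; everything else becomes null.
CheckOverflow(Literal(Decimal("12.345")), DecimalType(4, 2)).eval()  // 12.35
CheckOverflow(Literal(Decimal("123.45")), DecimalType(4, 3)).eval()  // null
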
Example 29
Source File: ReferenceToExpressions.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.types.DataType


case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
  extends Expression {

  override def nullable: Boolean = result.nullable
  override def dataType: DataType = result.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (result.references.nonEmpty) {
      return TypeCheckFailure("The result expression cannot reference to any attributes.")
    }

    var maxOrdinal = -1
    result foreach {
      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
      case _ =>
    }
    if (maxOrdinal > children.length) {
      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
        s"there are only ${children.length} inputs.")
    }

    TypeCheckSuccess
  }

  private lazy val projection = UnsafeProjection.create(children)

  override def eval(input: InternalRow): Any = {
    result.eval(projection(input))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childrenGen = children.map(_.genCode(ctx))
    val childrenVars = childrenGen.zip(children).map {
      case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
    }

    val resultGen = result.transform {
      case b: BoundReference => childrenVars(b.ordinal)
    }.genCode(ctx)

    ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code,
      isNull = resultGen.isNull, value = resultGen.value)
  }
} 
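
A tiny evaluation sketch (ours): the result expression addresses the children positionally through BoundReference, so ReferenceToExpressions behaves like a let-binding over the projected children.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val sum = ReferenceToExpressions(
  Add(BoundReference(0, IntegerType, nullable = false),
      BoundReference(1, IntegerType, nullable = false)),
  Seq(Literal(2), Literal(3)))

sum.eval(InternalRow.empty)  // 5: the children are projected first, then result reads them
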
Example 30
Source File: ExpressionEvalHelperSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}


case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
} 
Example 31
Source File: package.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import java.util.Collections

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.util.{AccumulatorV2, LongAccumulator}


    case class ColumnMetrics() {
      val elementTypes = new SetAccumulator[String]
      sparkContext.register(elementTypes)
    }

    val tupleCount: LongAccumulator = sparkContext.longAccumulator

    val numColumns: Int = child.output.size
    val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics())

    def dumpStats(): Unit = {
      debugPrint(s"== ${child.simpleString} ==")
      debugPrint(s"Tuples output: ${tupleCount.value}")
      child.output.zip(columnStats).foreach { case (attr, metric) =>
        // This is called on driver. All accumulator updates have a fixed value. So it's safe to use
        // `asScala` which accesses the internal values using `java.util.Iterator`.
        val actualDataTypes = metric.elementTypes.value.asScala.mkString("{", ",", "}")
        debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
      }
    }

    protected override def doExecute(): RDD[InternalRow] = {
      child.execute().mapPartitions { iter =>
        new Iterator[InternalRow] {
          def hasNext: Boolean = iter.hasNext

          def next(): InternalRow = {
            val currentRow = iter.next()
            tupleCount.add(1)
            var i = 0
            while (i < numColumns) {
              val value = currentRow.get(i, output(i).dataType)
              if (value != null) {
                columnStats(i).elementTypes.add(value.getClass.getName)
              }
              i += 1
            }
            currentRow
          }
        }
      }
    }

    override def outputPartitioning: Partitioning = child.outputPartitioning

    override def inputRDDs(): Seq[RDD[InternalRow]] = {
      child.asInstanceOf[CodegenSupport].inputRDDs()
    }

    override def doProduce(ctx: CodegenContext): String = {
      child.asInstanceOf[CodegenSupport].produce(ctx, this)
    }

    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
      consume(ctx, input)
    }
  }
} 
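
The snippet above is the truncated body of the DebugExec debugging operator (its class declaration is cut off in this listing). End users normally reach it through the debug implicits; a minimal sketch:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.debug._

val spark = SparkSession.builder().master("local[*]").appName("debug-demo").getOrCreate()
val df = spark.range(10).selectExpr("id", "id * 2 AS twice")

df.debug()         // wraps each operator, runs the query and prints per-column type statistics
df.debugCodegen()  // prints the whole-stage generated Java source
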
Example 32
Source File: subquery.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}


case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.exchangeReuseEnabled) {
      return plan
    }
    // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls.
    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
    plan transformAllExpressions {
      case sub: ExecSubqueryExpression =>
        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
        val sameResult = sameSchema.find(_.sameResult(sub.plan))
        if (sameResult.isDefined) {
          sub.withNewPlan(sameResult.get)
        } else {
          sameSchema += sub.plan
          sub
        }
    }
  }
} 
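
The rule is gated on the same flag as exchange reuse (spark.sql.exchange.reuse). A spark-shell sketch (ours) that gives it two identical scalar subqueries to deduplicate:

spark.conf.set("spark.sql.exchange.reuse", "true")  // the default; ReuseSubquery checks this flag

val q = spark.sql(
  """SELECT id FROM range(100)
    |WHERE id > (SELECT avg(id) FROM range(100))
    |  AND id < (SELECT avg(id) FROM range(100)) + 10""".stripMargin)

// After physical planning, both predicates should point at the same SubqueryExec instance.
q.explain()
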
Example 33
Source File: TimestampCast.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.codegen.{ CodegenContext, ExprCode, CodeGenerator, JavaCode, Block }
import org.apache.spark.sql.catalyst.expressions.{ Expression, NullIntolerant, UnaryExpression }
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{ DataType, LongType, TimestampType }

case class TimestampToNanos(child: Expression) extends TimestampCast {
  val dataType: DataType = LongType
  protected def cast(childPrim: String): String =
    s"$childPrim * 1000L"
  override protected def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Long] * 1000L
}

case class NanosToTimestamp(child: Expression) extends TimestampCast {
  val dataType: DataType = TimestampType
  protected def cast(childPrim: String): String =
    s"$childPrim / 1000L"
  override protected def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Long] / 1000L
}

trait TimestampCast extends UnaryExpression with NullIntolerant {

  protected def cast(childPrim: String): String

  private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String,
    resultPrim: String, resultNull: String, resultType: DataType): Block = {
    code"""
      boolean $resultNull = $childNull;
      ${CodeGenerator.javaType(resultType)} $resultPrim = ${CodeGenerator.defaultValue(resultType)};
      if (!${childNull}) {
        $resultPrim = (long) ${cast(childPrim)};
      }
    """
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    ev.copy(code = eval.code +
      castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType))
  }
} 
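
These expressions are typically surfaced as Column helpers; one possible wrapping (the helper names are ours and may not match flint's public API):

import org.apache.spark.sql.{Column, NanosToTimestamp, TimestampToNanos}

def timestampToNanos(c: Column): Column = new Column(TimestampToNanos(c.expr))
def nanosToTimestamp(c: Column): Column = new Column(NanosToTimestamp(c.expr))

// e.g. df.select(timestampToNanos(col("time")))
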
Example 34
Source File: InputFileName.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.rdd.InputFileNameHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String


@ExpressionDescription(
  usage = "_FUNC_() - Returns the name of the current file being read if available",
  extended = "> SELECT _FUNC_();\n ''")
case class InputFileName() extends LeafExpression with Nondeterministic {

  override def nullable: Boolean = true

  override def dataType: DataType = StringType

  override def prettyName: String = "input_file_name"

  override protected def initInternal(): Unit = {}

  override protected def evalInternal(input: InternalRow): UTF8String = {
    InputFileNameHolder.getInputFileName()
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
      "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();", isNull = "false")
  }
} 
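
In user code the expression is reached through the public input_file_name() function; a minimal sketch (the input path is a placeholder):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.input_file_name

val spark = SparkSession.builder().master("local[*]").appName("input-file-demo").getOrCreate()

spark.read.text("/tmp/some/dir")  // placeholder path
  .select(input_file_name().as("source_file"))
  .show(truncate = false)
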
Example 35
Source File: MonotonicallyIncreasingID.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType}


case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {

  @transient private[this] var count: Long = _

  @transient private[this] var partitionMask: Long = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    count = 0L
    partitionMask = partitionIndex.toLong << 33
  }

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override protected def evalInternal(input: InternalRow): Long = {
    val currentCount = count
    count += 1
    partitionMask + currentCount
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
    val partitionMaskTerm = "partitionMask"
    ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm)
    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

    ev.copy(code = code"""
      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
      $countTerm++;""", isNull = FalseLiteral)
  }

  override def prettyName: String = "monotonically_increasing_id"

  override def sql: String = s"$prettyName()"

  override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID()
} 
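
User code normally reaches this expression through monotonically_increasing_id(); a small sketch, assuming an active SparkSession named spark:

import org.apache.spark.sql.functions.monotonically_increasing_id

// With four partitions, each partition contributes ids of the form
// (partitionIndex << 33) + rowIndexWithinPartition, matching the codegen above.
spark.range(start = 0, end = 8, step = 1, numPartitions = 4)
  .withColumn("mono_id", monotonically_increasing_id())
  .show()
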
Example 36
Source File: randomExpressions.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom


// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""",
  examples = """
    Examples:
      > SELECT _FUNC_();
       -0.3254147983080288
      > SELECT _FUNC_(0);
       1.1164209726833079
      > SELECT _FUNC_(null);
       1.1164209726833079
  """,
  note = "The function is non-deterministic in general case.")
// scalastyle:on line.size.limit
case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed {

  def this() = this(Literal(Utils.random.nextLong(), LongType))

  override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType))

  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val className = classOf[XORShiftRandom].getName
    val rngTerm = ctx.addMutableState(className, "rng")
    ctx.addPartitionInitializationStatement(
      s"$rngTerm = new $className(${seed}L + partitionIndex);")
    ev.copy(code = code"""
      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
      isNull = FalseLiteral)
  }

  override def freshCopy(): Randn = Randn(child)
}

object Randn {
  def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
} 
Example 37
Source File: TimeWindow.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

case class TimeWindow(
    timeColumn: Expression,
    windowDuration: Long,
    slideDuration: Long,
    startTime: Long) extends UnaryExpression
  with ImplicitCastInputTypes
  with Unevaluable
  with NonSQLExpression {

  //////////////////////////
  // SQL Constructors
  //////////////////////////

  def this(
      timeColumn: Expression,
      windowDuration: Expression,
      slideDuration: Expression,
      startTime: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime))
  }

  def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = {
    this(timeColumn, TimeWindow.parseExpression(windowDuration),
      TimeWindow.parseExpression(slideDuration), 0)
  }

  def this(timeColumn: Expression, windowDuration: Expression) = {
    this(timeColumn, windowDuration, windowDuration)
  }

  override def child: Expression = timeColumn
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
  override def dataType: DataType = new StructType()
    .add(StructField("start", TimestampType))
    .add(StructField("end", TimestampType))

  // This expression is replaced in the analyzer.
  override lazy val resolved = false
}

case class PreciseTimestampConversion(
    child: Expression,
    fromType: DataType,
    toType: DataType) extends UnaryExpression with ExpectsInputTypes {
  override def inputTypes: Seq[AbstractDataType] = Seq(fromType)
  override def dataType: DataType = toType
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    ev.copy(code = eval.code +
      code"""boolean ${ev.isNull} = ${eval.isNull};
         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
       """.stripMargin)
  }
  override def nullSafeEval(input: Any): Any = input
} 
Example 38
Source File: constraintExpressions.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
import org.apache.spark.sql.types.DataType

case class KnownNotNull(child: Expression) extends UnaryExpression {
  override def nullable: Boolean = false
  override def dataType: DataType = child.dataType

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    child.genCode(ctx).copy(isNull = FalseLiteral)
  }

  override def eval(input: InternalRow): Any = {
    child.eval(input)
  }
} 
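
The effect is easy to check with a bare CodegenContext; a REPL-style sketch (ours) against the Spark 2.4-era codegen API used here:

import org.apache.spark.sql.catalyst.expressions.{BoundReference, KnownNotNull}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, FalseLiteral}
import org.apache.spark.sql.types.StringType

val ctx = new CodegenContext
val ev = KnownNotNull(BoundReference(0, StringType, nullable = true)).genCode(ctx)

assert(ev.isNull == FalseLiteral)  // the child's code is reused; only the null flag is pinned to false
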
Example 39
Source File: BoundAttribute.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._


case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression {

  override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)

  // Use special getter for primitive types (for UnsafeRow)
  override def eval(input: InternalRow): Any = {
    if (nullable && input.isNullAt(ordinal)) {
      null
    } else {
      accessor(input, ordinal)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      ev.copy(code = oev.code)
    } else {
      assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
      val javaType = JavaCode.javaType(dataType)
      val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
      if (nullable) {
        ev.copy(code =
          code"""
             |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
             |$javaType ${ev.value} = ${ev.isNull} ?
             |  ${CodeGenerator.defaultValue(dataType)} : ($value);
           """.stripMargin)
      } else {
        ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
      }
    }
  }
}

object BindReferences extends Logging {

  def bindReference[A <: Expression](
      expression: A,
      input: AttributeSeq,
      allowFailures: Boolean = false): A = {
    expression.transform { case a: AttributeReference =>
      attachTree(a, "Binding attribute") {
        val ordinal = input.indexOf(a.exprId)
        if (ordinal == -1) {
          if (allowFailures) {
            a
          } else {
            sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
          }
        } else {
          BoundReference(ordinal, a.dataType, input(ordinal).nullable)
        }
      }
    }.asInstanceOf[A] // Kind of a hack, but safe.  TODO: Tighten return type when possible.
  }
} 
Example 40
Source File: decimalExpressions.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
import org.apache.spark.sql.types._


case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val tmp = ctx.freshName("tmp")
      s"""
         | Decimal $tmp = $eval.clone();
         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
         |   ${ev.value} = $tmp;
         | } else {
         |   ${ev.isNull} = true;
         | }
       """.stripMargin
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql
} 
Example 41
Source File: ScalaUDFSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase(Locale.ROOT),
      StringType,
      Literal.create(null, StringType) :: Nil,
      true :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvaluationWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

  test("SPARK-22695: ScalaUDF should not use global variables") {
    val ctx = new CodegenContext
    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil).genCode(ctx)
    assert(ctx.inlinedMutableStates.isEmpty)
  }
} 
Example 42
Source File: ExpressionEvalHelperSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}


case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      code"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
} 
Example 43
Source File: subquery.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}


case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.exchangeReuseEnabled) {
      return plan
    }
    // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls.
    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
    plan transformAllExpressions {
      case sub: ExecSubqueryExpression =>
        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
        val sameResult = sameSchema.find(_.sameResult(sub.plan))
        if (sameResult.isDefined) {
          sub.withNewPlan(sameResult.get)
        } else {
          sameSchema += sub.plan
          sub
        }
    }
  }
} 
Example 44
Source File: ColumnarSubquery.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package com.intel.sparkColumnarPlugin.expression

import org.apache.arrow.gandiva.evaluator._
import org.apache.arrow.gandiva.exceptions.GandivaException
import org.apache.arrow.gandiva.expression._
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.Field

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.execution.BaseSubqueryExec
import org.apache.spark.sql.execution.ExecSubqueryExpression
import org.apache.spark.sql.execution.ScalarSubquery
import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

class ColumnarScalarSubquery(
  query: ScalarSubquery)
  extends Expression with ColumnarExpression {

  override def dataType: DataType = query.dataType
  override def children: Seq[Expression] = Nil
  override def nullable: Boolean = true
  override def toString: String = query.toString
  override def eval(input: InternalRow): Any = query.eval(input)
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = query.doGenCode(ctx, ev)
  override def canEqual(that: Any): Boolean = query.canEqual(that)
  override def productArity: Int = query.productArity
  override def productElement(n: Int): Any = query.productElement(n)
  override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
    val value = query.eval(null)
    val resultType = CodeGeneration.getResultType(query.dataType)
    query.dataType match {
      case t: StringType =>
        (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType)
      case t: IntegerType =>
        (TreeBuilder.makeLiteral(value.asInstanceOf[Integer]), resultType)
      case t: LongType =>
        (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Long]), resultType)
      case t: DoubleType =>
        (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType)
      case d: DecimalType =>
        val v = value.asInstanceOf[Decimal]
        (TreeBuilder.makeDecimalLiteral(v.toString, v.precision, v.scale), resultType)
      case d: DateType =>
        throw new UnsupportedOperationException(s"DateType is not supported yet.")
    }
  }
}