org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: GroupedIterator.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateOrdering, GenerateUnsafeProjection}

object GroupedIterator {
  /**
   * Factory for a [[GroupedIterator]] over `input`.
   *
   * @param input rows to be grouped
   * @param keyExpressions expressions handed to the `GroupedIterator` constructor as its keys
   * @param inputSchema attributes handed to the `GroupedIterator` constructor as the row schema
   * @return an empty iterator when `input` has no rows; otherwise a `GroupedIterator` built on a
   *         buffered view of `input`
   */
  def apply(
      input: Iterator[InternalRow],
      keyExpressions: Seq[Expression],
      inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
    if (!input.hasNext) {
      // Nothing to group.
      Iterator.empty
    } else {
      new GroupedIterator(input.buffered, keyExpressions, inputSchema)
    }
  }
}


  /**
   * Returns true when a group iterator is already pending (`currentIterator` is non-null) or
   * when `fetchNextGroupIterator()` manages to position the iterator on another group.
   *
   * `fetchNextGroupIterator()` is declared with an empty parameter list and mutates state, so it
   * is invoked with explicit parentheses.
   */
  def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator()

  /**
   * Returns the projected key of the current group together with the iterator over its rows,
   * and clears `currentIterator` so that the following `hasNext` call fetches a new group.
   */
  def next(): (InternalRow, Iterator[InternalRow]) = {
    // Evaluating `hasNext` makes sure the next group's iterator has been fetched.
    assert(hasNext)
    val groupKey = keyProjection(currentGroup)
    val groupRows = currentIterator
    currentIterator = null
    (groupKey, groupRows)
  }

  /**
   * Tries to position this iterator on the first row of the next group.
   *
   * Must only be called while `currentIterator` is null. Rows that still compare equal to
   * `currentGroup` under `keyOrdering` are skipped.
   *
   * @return true if a row with a different key was found, in which case `currentGroup` is set to
   *         a copy of that row and `currentIterator` to a fresh iterator over its group; false
   *         if the input ran out first.
   */
  private def fetchNextGroupIterator(): Boolean = {
    assert(currentIterator == null)

    // Pull a row from the input if none is currently held.
    if (currentRow == null && input.hasNext) {
      currentRow = input.next()
    }

    if (currentRow == null) {
      // There is no data left, return false.
      false
    } else {
      // Skip to next group.
      // currentRow may be overwritten by `hasNext`, so we should compare them first.
      while (keyOrdering.compare(currentGroup, currentRow) == 0 && input.hasNext) {
        currentRow = input.next()
      }

      if (keyOrdering.compare(currentGroup, currentRow) == 0) {
        // We are in the last group, there are no more groups, return false.
        false
      } else {
        // Now the `currentRow` is the first row of next group.
        currentGroup = currentRow.copy()
        currentIterator = createGroupValuesIterator()
        true
      }
    }
  }

  /**
   * Builds the iterator over the rows of the current group.
   *
   * The returned iterator hands out `currentRow` and then pulls further rows from `input` for as
   * long as they compare equal to `currentGroup` under `keyOrdering`. It never consumes a row
   * that belongs to a different group.
   */
  private def createGroupValuesIterator(): Iterator[InternalRow] = new Iterator[InternalRow] {
    def hasNext: Boolean = currentRow != null || fetchNextRowInGroup()

    def next(): InternalRow = {
      assert(hasNext)
      val row = currentRow
      currentRow = null
      row
    }

    private def fetchNextRowInGroup(): Boolean = {
      assert(currentRow == null)

      // Peek at the upcoming row with `head` instead of consuming it: the row is only taken
      // from `input` when it still belongs to the current group. False covers both "input is
      // exhausted" and "next row starts another group".
      val nextRowInGroup =
        input.hasNext && keyOrdering.compare(currentGroup, input.head) == 0
      if (nextRowInGroup) {
        currentRow = input.next()
      }
      nextRowInGroup
    }
  }
} 
Example 2
Source File: GroupedIterator.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateOrdering, GenerateUnsafeProjection}

object GroupedIterator {
  /**
   * Factory for a [[GroupedIterator]] over `input`.
   *
   * @param input rows to be grouped
   * @param keyExpressions expressions handed to the `GroupedIterator` constructor as its keys
   * @param inputSchema attributes handed to the `GroupedIterator` constructor as the row schema
   * @return an empty iterator when `input` has no rows; otherwise a `GroupedIterator` built on a
   *         buffered view of `input`
   */
  def apply(
      input: Iterator[InternalRow],
      keyExpressions: Seq[Expression],
      inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
    if (!input.hasNext) {
      // Nothing to group.
      Iterator.empty
    } else {
      new GroupedIterator(input.buffered, keyExpressions, inputSchema)
    }
  }
}


  /**
   * Returns true when a group iterator is already pending (`currentIterator` is non-null) or
   * when `fetchNextGroupIterator()` manages to position the iterator on another group.
   *
   * `fetchNextGroupIterator()` is declared with an empty parameter list and mutates state, so it
   * is invoked with explicit parentheses.
   */
  def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator()

  /**
   * Returns the projected key of the current group together with the iterator over its rows,
   * and clears `currentIterator` so that the following `hasNext` call fetches a new group.
   */
  def next(): (InternalRow, Iterator[InternalRow]) = {
    // Evaluating `hasNext` makes sure the next group's iterator has been fetched.
    assert(hasNext)
    val groupKey = keyProjection(currentGroup)
    val groupRows = currentIterator
    currentIterator = null
    (groupKey, groupRows)
  }

  /**
   * Tries to position this iterator on the first row of the next group.
   *
   * Must only be called while `currentIterator` is null. Rows that still compare equal to
   * `currentGroup` under `keyOrdering` are skipped.
   *
   * @return true if a row with a different key was found, in which case `currentGroup` is set to
   *         a copy of that row and `currentIterator` to a fresh iterator over its group; false
   *         if the input ran out first.
   */
  private def fetchNextGroupIterator(): Boolean = {
    assert(currentIterator == null)

    // Pull a row from the input if none is currently held.
    if (currentRow == null && input.hasNext) {
      currentRow = input.next()
    }

    if (currentRow == null) {
      // There is no data left, return false.
      false
    } else {
      // Skip to next group.
      // currentRow may be overwritten by `hasNext`, so we should compare them first.
      while (keyOrdering.compare(currentGroup, currentRow) == 0 && input.hasNext) {
        currentRow = input.next()
      }

      if (keyOrdering.compare(currentGroup, currentRow) == 0) {
        // We are in the last group, there are no more groups, return false.
        false
      } else {
        // Now the `currentRow` is the first row of next group.
        currentGroup = currentRow.copy()
        currentIterator = createGroupValuesIterator()
        true
      }
    }
  }

  /**
   * Builds the iterator over the rows of the current group.
   *
   * The returned iterator hands out `currentRow` and then pulls further rows from `input` for as
   * long as they compare equal to `currentGroup` under `keyOrdering`. It never consumes a row
   * that belongs to a different group.
   */
  private def createGroupValuesIterator(): Iterator[InternalRow] = new Iterator[InternalRow] {
    def hasNext: Boolean = currentRow != null || fetchNextRowInGroup()

    def next(): InternalRow = {
      assert(hasNext)
      val row = currentRow
      currentRow = null
      row
    }

    private def fetchNextRowInGroup(): Boolean = {
      assert(currentRow == null)

      // Peek at the upcoming row with `head` instead of consuming it: the row is only taken
      // from `input` when it still belongs to the current group. False covers both "input is
      // exhausted" and "next row starts another group".
      val nextRowInGroup =
        input.hasNext && keyOrdering.compare(currentGroup, input.head) == 0
      if (nextRowInGroup) {
        currentRow = input.next()
      }
      nextRowInGroup
    }
  }
} 
Example 3
Source File: DeltaInvariantCheckerExec.scala    From delta   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.delta.schema

import org.apache.spark.sql.delta.DeltaErrors
import org.apache.spark.sql.delta.schema.Invariants.NotNull

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BindReferences, Expression, GetStructField, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.types.{NullType, StructType}


  /**
   * Builds the bound expression that reads the column an invariant refers to.
   *
   * The first element of `invariant.column` is resolved against `output`; any remaining elements
   * are resolved one at a time as struct fields of the expression built so far.
   *
   * @param invariant the invariant whose column path should be extracted; the path must be
   *                  non-empty
   * @return the extracting expression, or None when the column (or one of its nested fields)
   *         does not exist and `isNullNotOkay(invariant)` is false
   * @note throws the error from `DeltaErrors.notNullInvariantException` or an
   *       `InvariantViolationException` when the column is missing and
   *       `isNullNotOkay(invariant)` is true, and an `UnsupportedOperationException` when a
   *       nested path element is applied to a non-struct type
   */
  private def buildExtractors(invariant: Invariant): Option[Expression] = {
    assert(invariant.column.nonEmpty)
    val rootName = invariant.column.head
    val rootAttrOpt = output.collectFirst {
      case attr: AttributeReference
          if SchemaUtils.DELTA_COL_RESOLVER(attr.name, rootName) => attr
    }
    val rejectColumnNotFound = isNullNotOkay(invariant)
    if (rootAttrOpt.isEmpty && rejectColumnNotFound) {
      throw DeltaErrors.notNullInvariantException(invariant)
    }

    rootAttrOpt.flatMap { rootAttr =>
      val boundRoot = BindReferences.bindReference[Expression](rootAttr, output)
      if (invariant.column.length == 1) {
        // Top-level column: the bound reference is the whole extractor.
        Some(boundRoot)
      } else {
        try {
          // Descend through the remaining path elements, one struct field at a time.
          val extractor = invariant.column.tail.foldLeft(boundRoot) { (parent, fieldName) =>
            parent.dataType match {
              case StructType(fields) =>
                val ordinal = fields.indexWhere { field =>
                  SchemaUtils.DELTA_COL_RESOLVER(field.name, fieldName)
                }
                if (ordinal == -1) {
                  throw new IndexOutOfBoundsException(
                    "Not nullable column not found in struct: " +
                      fields.map(_.name).mkString("[", ",", "]"))
                }
                GetStructField(parent, ordinal, Some(fieldName))
              case _ =>
                throw new UnsupportedOperationException(
                  "Invariants on nested fields other than StructTypes are not supported.")
            }
          }
          Some(extractor)
        } catch {
          // A missing nested field is an error only when the invariant rejects missing columns.
          case missing: IndexOutOfBoundsException =>
            if (rejectColumnNotFound) {
              throw InvariantViolationException(invariant, missing.getMessage)
            } else {
              None
            }
        }
      }
    }
  }

  /**
   * Executes the child plan and evaluates every invariant against each row it produces.
   *
   * When there are no invariants the child's RDD is returned untouched. Otherwise one
   * `CheckDeltaInvariant` expression is built per invariant (using a null literal when no
   * extractor could be built) and a projection generated from them is applied to every row.
   * The projection's result is discarded and the input row is passed through.
   */
  override protected def doExecute(): RDD[InternalRow] = {
    if (invariants.isEmpty) {
      child.execute()
    } else {
      val checks = invariants.map { invariant =>
        val extractor: Expression =
          buildExtractors(invariant).getOrElse(Literal(null, NullType))
        CheckDeltaInvariant(extractor, invariant)
      }

      child.execute().mapPartitionsInternal { partition =>
        // The projection is generated inside the partition function, once per partition.
        val runChecks = GenerateUnsafeProjection.generate(checks)
        partition.map { row =>
          runChecks(row)
          row
        }
      }
    }
  }

  /** Reports the child's output ordering unchanged. */
  override def outputOrdering: Seq[SortOrder] = child.outputOrdering

  /** Reports the child's output partitioning unchanged. */
  override def outputPartitioning: Partitioning = child.outputPartitioning
} 
Example 4
Source File: ComplexDataSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.util

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

/**
 * Tests for complex catalyst data holders: inequality of map data, and independence of `copy()`
 * results from the values they were copied from.
 */
class ComplexDataSuite extends SparkFunSuite {
  /** Shorthand for turning a JVM string into a [[UTF8String]]. */
  def utf8(str: String): UTF8String = UTF8String.fromString(str)

  /** Generates a projection over a single nullable string column at ordinal 0. */
  private def stringColumnProjection(): UnsafeProjection =
    GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, nullable = true)))

  test("inequality tests for MapData") {
    // Two pairs of maps built from the same entries: (A, B) with one key and with two keys.
    val oneKeyA = Map(utf8("key1") -> 1)
    val twoKeysA = Map(utf8("key1") -> 1, utf8("key2") -> 2)
    val oneKeyB = Map(utf8("key1") -> 1)
    val twoKeysB = Map(utf8("key1") -> 1, utf8("key2") -> 2)

    // ArrayBasedMapData
    val arrayOneKeyA = ArrayBasedMapData(oneKeyA.toMap)
    val arrayTwoKeysA = ArrayBasedMapData(twoKeysA.toMap)
    val arrayOneKeyB = ArrayBasedMapData(oneKeyB.toMap)
    val arrayTwoKeysB = ArrayBasedMapData(twoKeysB.toMap)
    assert(arrayOneKeyA !== arrayOneKeyB)
    assert(arrayTwoKeysA !== arrayTwoKeysB)

    // UnsafeMapData
    val toUnsafeRow = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val scratchRow = new GenericInternalRow(1)
    // Writes the map into the shared scratch row, converts it, and copies the map back out.
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      scratchRow.update(0, map)
      toUnsafeRow(scratchRow).getMap(0).copy()
    }
    assert(toUnsafeMap(arrayOneKeyA) !== toUnsafeMap(arrayOneKeyB))
    assert(toUnsafeMap(arrayTwoKeysA) !== toUnsafeMap(arrayTwoKeysB))
  }

  test("GenericInternalRow.copy return a new instance that is independent from the old one") {
    val projection = stringColumnProjection()
    val projected = projection(InternalRow(utf8("a")))

    val sourceRow = new GenericInternalRow(Array[Any](projected.getUTF8String(0)))
    val snapshot = sourceRow.copy()
    assert(snapshot.getString(0) == "a")
    // Run the projection again with a different value; the copy must keep the old one.
    projection(InternalRow(utf8("b")))
    assert(snapshot.getString(0) == "a")
  }

  test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
    val projection = stringColumnProjection()
    val projected = projection(InternalRow(utf8("a")))

    val sourceRow = new SpecificInternalRow(Seq(StringType))
    sourceRow.update(0, projected.getUTF8String(0))
    val snapshot = sourceRow.copy()
    assert(snapshot.getString(0) == "a")
    // Run the projection again with a different value; the copy must keep the old one.
    projection(InternalRow(utf8("b")))
    assert(snapshot.getString(0) == "a")
  }

  test("GenericArrayData.copy return a new instance that is independent from the old one") {
    val projection = stringColumnProjection()
    val projected = projection(InternalRow(utf8("a")))

    val sourceArray = new GenericArrayData(Array[Any](projected.getUTF8String(0)))
    val snapshot = sourceArray.copy()
    assert(snapshot.getUTF8String(0).toString == "a")
    // Run the projection again with a different value; the copy must keep the old one.
    projection(InternalRow(utf8("b")))
    assert(snapshot.getUTF8String(0).toString == "a")
  }

  test("copy on nested complex type") {
    val projection = stringColumnProjection()
    val projected = projection(InternalRow(utf8("a")))

    // An array whose single element is a one-field row holding the projected string.
    val nestedRow = InternalRow(projected.getUTF8String(0))
    val sourceArray = new GenericArrayData(Array[Any](nestedRow))
    val snapshot = sourceArray.copy()
    assert(snapshot.getStruct(0, 1).getUTF8String(0).toString == "a")
    // Run the projection again with a different value; the copy must keep the old one.
    projection(InternalRow(utf8("b")))
    assert(snapshot.getStruct(0, 1).getUTF8String(0).toString == "a")
  }
}