org.apache.spark.sql.catalyst.expressions.UnsafeArrayData Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.expressions.UnsafeArrayData. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: MatrixUDT.scala (from multi-tenancy-spark, Apache License 2.0; 5 votes)

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

/**
 * Spark SQL user-defined type for [[Matrix]]: a matrix is encoded as a 7-field struct row.
 *
 * Field layout: 0 = type tag (0 = sparse, 1 = dense), 1 = numRows, 2 = numCols,
 * 3 = colPtrs (sparse only), 4 = rowIndices (sparse only), 5 = values, 6 = isTransposed.
 */
private[spark] class MatrixUDT extends UserDefinedType[Matrix] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
    // set as not nullable, except values since in the future, support for binary matrices might
    // be added for which values are not needed.
    // the sparse matrix needs colPtrs and rowIndices, which are set as
    // null, while building the dense matrix.
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("numRows", IntegerType, nullable = false),
      StructField("numCols", IntegerType, nullable = false),
      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
      StructField("isTransposed", BooleanType, nullable = false)
    ))
  }

  /** Encodes `obj` into the 7-field layout described by [[sqlType]]. */
  override def serialize(obj: Matrix): InternalRow = {
    val row = new GenericInternalRow(7)
    obj match {
      case sm: SparseMatrix =>
        row.setByte(0, 0)
        row.setInt(1, sm.numRows)
        row.setInt(2, sm.numCols)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
        row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
        row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
        row.setBoolean(6, sm.isTransposed)

      case dm: DenseMatrix =>
        row.setByte(0, 1)
        row.setInt(1, dm.numRows)
        row.setInt(2, dm.numCols)
        // Dense matrices have no colPtrs / rowIndices; store them as null explicitly.
        row.setNullAt(3)
        row.setNullAt(4)
        row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
        row.setBoolean(6, dm.isTransposed)
    }
    row
  }

  /**
   * Decodes a row produced by [[serialize]] back into a [[Matrix]].
   *
   * @throws IllegalArgumentException if the row does not have 7 fields or carries an unknown
   *                                  type tag
   */
  override def deserialize(datum: Any): Matrix = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 7,
          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
        val tpe = row.getByte(0)
        val numRows = row.getInt(1)
        val numCols = row.getInt(2)
        val values = row.getArray(5).toDoubleArray()
        val isTransposed = row.getBoolean(6)
        tpe match {
          case 0 =>
            val colPtrs = row.getArray(3).toIntArray()
            val rowIndices = row.getArray(4).toIntArray()
            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
          case 1 =>
            new DenseMatrix(numRows, numCols, values, isTransposed)
          case other =>
            throw new IllegalArgumentException(
              s"MatrixUDT.deserialize given unknown matrix type tag $other")
        }
    }
  }

  override def userClass: Class[Matrix] = classOf[Matrix]

  // All MatrixUDT instances are interchangeable.
  override def equals(o: Any): Boolean = {
    o match {
      case _: MatrixUDT => true
      case _ => false
    }
  }

  // see [SPARK-8647], this achieves the needed constant hash code without constant no.
  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()

  override def typeName: String = "matrix"

  override def pyUDT: String = ""

  private[spark] override def asNullable: MatrixUDT = this
}
Example 2
Source File: ShapeUtils.scala (from Simba, Apache License 2.0; 5 votes)
package org.apache.spark.sql.simba.util

import org.apache.spark.sql.simba.{ShapeSerializer, ShapeType}
import org.apache.spark.sql.simba.expression.PointWrapper
import org.apache.spark.sql.simba.spatial.{Point, Shape}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, UnsafeArrayData}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

/** Helpers for extracting spatial [[Point]]s and [[Shape]]s from Catalyst rows and expressions. */
object ShapeUtils {

  /**
   * Builds a [[Point]] from `row`, resolving `columns` against a physical plan's output.
   *
   * @param row     the input row
   * @param columns when `isPoint` is true, only `columns.head` is used and must hold a serialized
   *                shape; otherwise each column is evaluated as one numeric coordinate
   * @param plan    the plan whose `output` the columns are bound against
   * @param isPoint whether the point is stored serialized in a single column
   */
  def getPointFromRow(row: InternalRow, columns: List[Attribute], plan: SparkPlan,
                      isPoint: Boolean): Point =
    pointFromRow(row, columns, plan.output, isPoint)

  /** Same as the [[SparkPlan]] overload, but binds `columns` against a logical plan's output. */
  def getPointFromRow(row: InternalRow, columns: List[Attribute], plan: LogicalPlan,
                      isPoint: Boolean): Point =
    pointFromRow(row, columns, plan.output, isPoint)

  /**
   * Evaluates an already-bound `expression` on `input` as a [[Shape]].
   *
   * @throws UnsupportedOperationException if the expression is neither a [[PointWrapper]] nor
   *                                       of [[ShapeType]]
   */
  def getShape(expression: Expression, input: InternalRow): Shape =
    evalShape(expression, expression, input)

  /** Like the two-argument overload, but first binds `expression` against `schema`. */
  def getShape(expression: Expression, schema: Seq[Attribute], input: InternalRow): Shape =
    evalShape(expression, BindReferences.bindReference(expression, schema), input)

  // Shared body of both getPointFromRow overloads; `output` is the bound plan's output.
  // NOTE(review): the expression tails here were lost in extraction and have been reconstructed.
  private def pointFromRow(row: InternalRow, columns: List[Attribute], output: Seq[Attribute],
                           isPoint: Boolean): Point = {
    if (isPoint) {
      ShapeSerializer.deserialize(BindReferences.bindReference(columns.head, output)
        .eval(row).asInstanceOf[UnsafeArrayData].toByteArray).asInstanceOf[Point]
    } else {
      Point(columns.toArray.map(BindReferences.bindReference(_, output).eval(row)
        .asInstanceOf[Number].doubleValue()))
    }
  }

  // Type dispatch is done on the unbound `expression`; `bound` is by-name so binding only
  // happens when the expression is actually evaluated.
  private def evalShape(expression: Expression, bound: => Expression, input: InternalRow): Shape =
    expression match {
      case _: PointWrapper =>
        bound.eval(input).asInstanceOf[Shape]
      case e if e.dataType.isInstanceOf[ShapeType] =>
        ShapeSerializer.deserialize(bound.eval(input).asInstanceOf[UnsafeArrayData].toByteArray)
      case _ =>
        throw new UnsupportedOperationException("Query shape should be of ShapeType")
    }
}

Example 3
Source File: VectorUDT.scala (from sona, Apache License 2.0; 5 votes)
package org.apache.spark.linalg

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

  /**
   * User-defined type for [[Vector]] which allows easy interaction with SQL
   * via [[org.apache.spark.sql.Dataset]].
   *
   * Field layout: 0 = type tag (0 = int_sparse, 1 = dense, 2 = long_sparse), 1 = size,
   * 2 = intIndices (int_sparse only), 3 = longIndices (long_sparse only), 4 = values.
   */
class VectorUDT extends UserDefinedType[Vector] {

  override final def sqlType: StructType = {
    // type: 0 = int_sparse, 1 = dense, 2 = long_sparse
    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
    // vectors. The "values" field is nullable because we might want to add binary vectors later,
    // which uses "size" and "indices", but not "values".
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("size", LongType, nullable = true),
      StructField("intIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("longIndices", ArrayType(LongType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
  }

  /** Encodes `obj` into the 5-field layout described by [[sqlType]]. */
  override def serialize(obj: Vector): InternalRow = {
    val row = new GenericInternalRow(5)
    obj match {
      case IntSparseVector(size, indices, values) =>
        row.setByte(0, 0)
        row.setLong(1, size)
        row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
        row.update(4, UnsafeArrayData.fromPrimitiveArray(values))
      case DenseVector(values) =>
        row.setByte(0, 1)
        row.update(4, UnsafeArrayData.fromPrimitiveArray(values))
      case LongSparseVector(size, indices, values) =>
        row.setByte(0, 2)
        row.setLong(1, size)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(indices))
        row.update(4, UnsafeArrayData.fromPrimitiveArray(values))
    }
    row
  }

  /**
   * Decodes a row produced by [[serialize]] back into a [[Vector]].
   *
   * @throws IllegalArgumentException if the row does not have 5 fields or carries an unknown
   *                                  type tag
   */
  override def deserialize(datum: Any): Vector = {
    datum match {
      case row: InternalRow =>
        // The message must agree with the checked arity (5 fields).
        require(row.numFields == 5,
          s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 5")
        val tpe = row.getByte(0)
        tpe match {
          case 0 =>
            val size = row.getLong(1)
            val indices = row.getArray(2).toIntArray()
            val values = row.getArray(4).toDoubleArray()
            new IntSparseVector(size, indices, values)
          case 1 =>
            val values = row.getArray(4).toDoubleArray()
            new DenseVector(values)
          case 2 =>
            val size = row.getLong(1)
            val indices = row.getArray(3).toLongArray()
            val values = row.getArray(4).toDoubleArray()
            new LongSparseVector(size, indices, values)
          case other =>
            throw new IllegalArgumentException(
              s"VectorUDT.deserialize given unknown vector type tag $other")
        }
    }
  }

  override def pyUDT: String = ""

  override def userClass: Class[Vector] = classOf[Vector]

  // All VectorUDT instances are interchangeable.
  override def equals(o: Any): Boolean = {
    o match {
      case _: VectorUDT => true
      case _ => false
    }
  }

  // see [SPARK-8647], this achieves the needed constant hash code without constant no.
  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()

  override def typeName: String = "vector"

  override def asNullable: VectorUDT = this
}
Example 4
Source File: MatrixUDT.scala (from sona, Apache License 2.0; 5 votes)
package org.apache.spark.linalg

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

 /**
  * User-defined type for [[Matrix]] which allows easy interaction with SQL
  * via [[org.apache.spark.sql.Dataset]].
  *
  * Field layout: 0 = type tag (0 = sparse, 1 = dense), 1 = numRows, 2 = numCols,
  * 3 = colPtrs (sparse only), 4 = rowIndices (sparse only), 5 = values, 6 = isTransposed.
  */
class MatrixUDT extends UserDefinedType[Matrix] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
    // set as not nullable, except values since in the future, support for binary matrices might
    // be added for which values are not needed.
    // the sparse matrix needs colPtrs and rowIndices, which are set as
    // null, while building the dense matrix.
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("numRows", IntegerType, nullable = false),
      StructField("numCols", IntegerType, nullable = false),
      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
      StructField("isTransposed", BooleanType, nullable = false)
    ))
  }

  /** Encodes `obj` into the 7-field layout described by [[sqlType]]. */
  override def serialize(obj: Matrix): InternalRow = {
    val row = new GenericInternalRow(7)
    obj match {
      case sm: SparseMatrix =>
        row.setByte(0, 0)
        row.setInt(1, sm.numRows)
        row.setInt(2, sm.numCols)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
        row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
        row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
        row.setBoolean(6, sm.isTransposed)

      case dm: DenseMatrix =>
        row.setByte(0, 1)
        row.setInt(1, dm.numRows)
        row.setInt(2, dm.numCols)
        // Dense matrices have no colPtrs / rowIndices; store them as null explicitly.
        row.setNullAt(3)
        row.setNullAt(4)
        row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
        row.setBoolean(6, dm.isTransposed)
    }
    row
  }

  /**
   * Decodes a row produced by [[serialize]] back into a [[Matrix]].
   *
   * @throws IllegalArgumentException if the row does not have 7 fields or carries an unknown
   *                                  type tag
   */
  override def deserialize(datum: Any): Matrix = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 7,
          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
        val tpe = row.getByte(0)
        val numRows = row.getInt(1)
        val numCols = row.getInt(2)
        val values = row.getArray(5).toDoubleArray()
        val isTransposed = row.getBoolean(6)
        tpe match {
          case 0 =>
            val colPtrs = row.getArray(3).toIntArray()
            val rowIndices = row.getArray(4).toIntArray()
            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
          case 1 =>
            new DenseMatrix(numRows, numCols, values, isTransposed)
          case other =>
            throw new IllegalArgumentException(
              s"MatrixUDT.deserialize given unknown matrix type tag $other")
        }
    }
  }

  override def userClass: Class[Matrix] = classOf[Matrix]

  // All MatrixUDT instances are interchangeable.
  override def equals(o: Any): Boolean = {
    o match {
      case _: MatrixUDT => true
      case _ => false
    }
  }

  // see [SPARK-8647], this achieves the needed constant hash code without constant no.
  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()

  override def typeName: String = "matrix"

  override def pyUDT: String = ""

  override def asNullable: MatrixUDT = this
}
Example 5
Source File: CatalystTypeConvertersSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

/** Tests for [[CatalystTypeConverters]]: null, Option and primitive/boxed array conversions. */
class CatalystTypeConvertersSuite extends SparkFunSuite {

  // Atomic types whose null handling is exercised below.
  // NOTE(review): the list body was lost in extraction and has been reconstructed; verify it
  // against the upstream suite.
  private val simpleTypes: Seq[DataType] = Seq(
    StringType,
    DateType,
    BooleanType,
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType.SYSTEM_DEFAULT,
    DecimalType.USER_DEFAULT)

  test("null handling in rows") {
    val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t)))
    val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
    val convertToScala = CatalystTypeConverters.createToScalaConverter(schema)

    val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null))
    assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow)
  }

  test("null handling for individual values") {
    for (dataType <- simpleTypes) {
      assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null)
    }
  }

  test("option handling in convertToCatalyst") {
    // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with
    // createToCatalystConverter but it may not actually matter as this is only called internally
    // in a handful of places where we don't expect to receive Options.
    assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123))
  }

  test("option handling in createToCatalystConverter") {
    assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
  }

  test("primitive array handling") {
    val intArray = Array(1, 100, 10000)
    val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray)
    val intArrayType = ArrayType(IntegerType, false)
    assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) === intArray)

    val doubleArray = Array(1.1, 111.1, 11111.1)
    val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray)
    val doubleArrayType = ArrayType(DoubleType, false)
    assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleUnsafeArray)
      === doubleArray)
  }

  test("An array with null handling") {
    val intArray = Array(1, null, 100, null, 10000)
    val intGenericArray = new GenericArrayData(intArray)
    val intArrayType = ArrayType(IntegerType, true)
    assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intGenericArray)
      === intArray)
    assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray)
      == intGenericArray)

    val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
    val doubleGenericArray = new GenericArrayData(doubleArray)
    val doubleArrayType = ArrayType(DoubleType, true)
    assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleGenericArray)
      === doubleArray)
    assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray)
      == doubleGenericArray)
  }
}
Example 6
Source File: VectorUDT.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

/**
 * Spark SQL user-defined type for [[Vector]]: a vector is encoded as a 4-field struct row.
 *
 * Field layout: 0 = type tag (0 = sparse, 1 = dense), 1 = size, 2 = indices, 3 = values.
 * Dense vectors only populate the tag and `values`.
 */
private[spark] class VectorUDT extends UserDefinedType[Vector] {

  override final def sqlType: StructType = _sqlType

  /** Encodes `obj` into the 4-field layout described by [[sqlType]]. */
  override def serialize(obj: Vector): InternalRow = {
    val row = new GenericInternalRow(4)
    obj match {
      case SparseVector(size, indices, values) =>
        row.setByte(0, 0)
        row.setInt(1, size)
        row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
      case DenseVector(values) =>
        row.setByte(0, 1)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
    }
    row
  }

  /**
   * Decodes a row produced by [[serialize]] back into a [[Vector]].
   *
   * @throws IllegalArgumentException if the row does not have 4 fields or carries an unknown
   *                                  type tag
   */
  override def deserialize(datum: Any): Vector = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 4,
          s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
        val tpe = row.getByte(0)
        tpe match {
          case 0 =>
            val size = row.getInt(1)
            val indices = row.getArray(2).toIntArray()
            val values = row.getArray(3).toDoubleArray()
            new SparseVector(size, indices, values)
          case 1 =>
            val values = row.getArray(3).toDoubleArray()
            new DenseVector(values)
          case other =>
            throw new IllegalArgumentException(
              s"VectorUDT.deserialize given unknown vector type tag $other")
        }
    }
  }

  override def pyUDT: String = ""

  override def userClass: Class[Vector] = classOf[Vector]

  // All VectorUDT instances are interchangeable.
  override def equals(o: Any): Boolean = {
    o match {
      case _: VectorUDT => true
      case _ => false
    }
  }

  // see [SPARK-8647], this achieves the needed constant hash code without constant no.
  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()

  override def typeName: String = "vector"

  private[spark] override def asNullable: VectorUDT = this

  // Built once per instance; `sqlType` is a def, so it only reads this after construction.
  private[this] val _sqlType = {
    // type: 0 = sparse, 1 = dense
    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
    // vectors. The "values" field is nullable because we might want to add binary vectors later,
    // which uses "size" and "indices", but not "values".
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("size", IntegerType, nullable = true),
      StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
  }
}
Example 7
Source File: MatrixUDT.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

/**
 * Spark SQL user-defined type for [[Matrix]]: a matrix is encoded as a 7-field struct row.
 *
 * Field layout: 0 = type tag (0 = sparse, 1 = dense), 1 = numRows, 2 = numCols,
 * 3 = colPtrs (sparse only), 4 = rowIndices (sparse only), 5 = values, 6 = isTransposed.
 */
private[spark] class MatrixUDT extends UserDefinedType[Matrix] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
    // set as not nullable, except values since in the future, support for binary matrices might
    // be added for which values are not needed.
    // the sparse matrix needs colPtrs and rowIndices, which are set as
    // null, while building the dense matrix.
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("numRows", IntegerType, nullable = false),
      StructField("numCols", IntegerType, nullable = false),
      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
      StructField("isTransposed", BooleanType, nullable = false)
    ))
  }

  /** Encodes `obj` into the 7-field layout described by [[sqlType]]. */
  override def serialize(obj: Matrix): InternalRow = {
    val row = new GenericInternalRow(7)
    obj match {
      case sm: SparseMatrix =>
        row.setByte(0, 0)
        row.setInt(1, sm.numRows)
        row.setInt(2, sm.numCols)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
        row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
        row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
        row.setBoolean(6, sm.isTransposed)

      case dm: DenseMatrix =>
        row.setByte(0, 1)
        row.setInt(1, dm.numRows)
        row.setInt(2, dm.numCols)
        // Dense matrices have no colPtrs / rowIndices; store them as null explicitly.
        row.setNullAt(3)
        row.setNullAt(4)
        row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
        row.setBoolean(6, dm.isTransposed)
    }
    row
  }

  /**
   * Decodes a row produced by [[serialize]] back into a [[Matrix]].
   *
   * @throws IllegalArgumentException if the row does not have 7 fields or carries an unknown
   *                                  type tag
   */
  override def deserialize(datum: Any): Matrix = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 7,
          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
        val tpe = row.getByte(0)
        val numRows = row.getInt(1)
        val numCols = row.getInt(2)
        val values = row.getArray(5).toDoubleArray()
        val isTransposed = row.getBoolean(6)
        tpe match {
          case 0 =>
            val colPtrs = row.getArray(3).toIntArray()
            val rowIndices = row.getArray(4).toIntArray()
            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
          case 1 =>
            new DenseMatrix(numRows, numCols, values, isTransposed)
          case other =>
            throw new IllegalArgumentException(
              s"MatrixUDT.deserialize given unknown matrix type tag $other")
        }
    }
  }

  override def userClass: Class[Matrix] = classOf[Matrix]

  // All MatrixUDT instances are interchangeable.
  override def equals(o: Any): Boolean = {
    o match {
      case _: MatrixUDT => true
      case _ => false
    }
  }

  // see [SPARK-8647], this achieves the needed constant hash code without constant no.
  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()

  override def typeName: String = "matrix"

  override def pyUDT: String = ""

  private[spark] override def asNullable: MatrixUDT = this
}
Example 8
Source File: UdtEncodedClass.scala (from frameless, Apache License 2.0; 5 votes)
package frameless

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.sql.FramelessInternals.UserDefinedType

/**
 * Plain (non-case) class encoded through [[UdtEncodedClassUdt]].
 *
 * @param a an integer field
 * @param b a primitive double array; `equals`, `hashCode` and `toString` all use its contents
 */
@SQLUserDefinedType(udt = classOf[UdtEncodedClassUdt])
class UdtEncodedClass(val a: Int, val b: Array[Double]) {
  override def equals(other: Any): Boolean = other match {
    case that: UdtEncodedClass => a == that.a && java.util.Arrays.equals(b, that.b)
    case _ => false
  }

  // Hash `b` by content so that instances equal under `equals` (which compares `b` with
  // java.util.Arrays.equals) also have equal hash codes; Array#hashCode is identity-based.
  override def hashCode(): Int = 31 * a.hashCode() + java.util.Arrays.hashCode(b)

  // Render the array's elements rather than its identity string (e.g. "[D@1b6d3586").
  override def toString: String = s"UdtEncodedClass($a, ${java.util.Arrays.toString(b)})"
}

/** Companion providing the implicit [[UdtEncodedClassUdt]] instance. */
object UdtEncodedClass {
  // Explicit type: public implicits should not rely on type inference.
  implicit val udtForUdtEncodedClass: UdtEncodedClassUdt = new UdtEncodedClassUdt
}

/** User-defined type storing a [[UdtEncodedClass]] as a struct of `a` (int) and `b` (doubles). */
class UdtEncodedClassUdt extends UserDefinedType[UdtEncodedClass] {
  def sqlType: DataType = {
    StructType(Seq(
      StructField("a", IntegerType, nullable = false),
      StructField("b", ArrayType(DoubleType, containsNull = false), nullable = false)
    ))
  }

  /** Encodes `obj` as a row whose slots line up with the two fields of [[sqlType]]. */
  def serialize(obj: UdtEncodedClass): InternalRow = {
    // One slot per sqlType field ("a", "b").
    val row = new GenericInternalRow(2)
    row.setInt(0, obj.a)
    row.update(1, UnsafeArrayData.fromPrimitiveArray(obj.b))
    row
  }

  /** Decodes a row produced by [[serialize]]; any non-row datum fails with a MatchError. */
  def deserialize(datum: Any): UdtEncodedClass = datum match {
    case row: InternalRow => new UdtEncodedClass(row.getInt(0), row.getArray(1).toDoubleArray())
  }

  def userClass: Class[UdtEncodedClass] = classOf[UdtEncodedClass]
}
Example 9
Source File: CatalystTypeConvertersSuite.scala (from multi-tenancy-spark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

/** Tests for [[CatalystTypeConverters]]: null, Option and primitive/boxed array conversions. */
class CatalystTypeConvertersSuite extends SparkFunSuite {

  // Atomic types whose null handling is exercised below.
  // NOTE(review): the list body was lost in extraction and has been reconstructed; verify it
  // against the upstream suite.
  private val simpleTypes: Seq[DataType] = Seq(
    StringType,
    DateType,
    BooleanType,
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType.SYSTEM_DEFAULT,
    DecimalType.USER_DEFAULT)

  test("null handling in rows") {
    val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t)))
    val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
    val convertToScala = CatalystTypeConverters.createToScalaConverter(schema)

    val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null))
    assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow)
  }

  test("null handling for individual values") {
    for (dataType <- simpleTypes) {
      assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null)
    }
  }

  test("option handling in convertToCatalyst") {
    // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with
    // createToCatalystConverter but it may not actually matter as this is only called internally
    // in a handful of places where we don't expect to receive Options.
    assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123))
  }

  test("option handling in createToCatalystConverter") {
    assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
  }

  test("primitive array handling") {
    val intArray = Array(1, 100, 10000)
    val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray)
    val intArrayType = ArrayType(IntegerType, false)
    assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) === intArray)

    val doubleArray = Array(1.1, 111.1, 11111.1)
    val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray)
    val doubleArrayType = ArrayType(DoubleType, false)
    assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleUnsafeArray)
      === doubleArray)
  }

  test("An array with null handling") {
    val intArray = Array(1, null, 100, null, 10000)
    val intGenericArray = new GenericArrayData(intArray)
    val intArrayType = ArrayType(IntegerType, true)
    assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intGenericArray)
      === intArray)
    assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray)
      == intGenericArray)

    val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
    val doubleGenericArray = new GenericArrayData(doubleArray)
    val doubleArrayType = ArrayType(DoubleType, true)
    assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleGenericArray)
      === doubleArray)
    assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray)
      == doubleGenericArray)
  }
}
Example 10
Source File: VectorUDT.scala (from multi-tenancy-spark, Apache License 2.0; 5 votes)

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

/**
 * Spark SQL user-defined type for [[Vector]]: a vector is encoded as a 4-field struct row.
 *
 * Field layout: 0 = type tag (0 = sparse, 1 = dense), 1 = size, 2 = indices, 3 = values.
 * Dense vectors only populate the tag and `values`.
 */
private[spark] class VectorUDT extends UserDefinedType[Vector] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
    // vectors. The "values" field is nullable because we might want to add binary vectors later,
    // which uses "size" and "indices", but not "values".
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("size", IntegerType, nullable = true),
      StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
  }

  /** Encodes `obj` into the 4-field layout described by [[sqlType]]. */
  override def serialize(obj: Vector): InternalRow = {
    val row = new GenericInternalRow(4)
    obj match {
      case SparseVector(size, indices, values) =>
        row.setByte(0, 0)
        row.setInt(1, size)
        row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
      case DenseVector(values) =>
        row.setByte(0, 1)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
    }
    row
  }

  /**
   * Decodes a row produced by [[serialize]] back into a [[Vector]].
   *
   * @throws IllegalArgumentException if the row does not have 4 fields or carries an unknown
   *                                  type tag
   */
  override def deserialize(datum: Any): Vector = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 4,
          s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
        val tpe = row.getByte(0)
        tpe match {
          case 0 =>
            val size = row.getInt(1)
            val indices = row.getArray(2).toIntArray()
            val values = row.getArray(3).toDoubleArray()
            new SparseVector(size, indices, values)
          case 1 =>
            val values = row.getArray(3).toDoubleArray()
            new DenseVector(values)
          case other =>
            throw new IllegalArgumentException(
              s"VectorUDT.deserialize given unknown vector type tag $other")
        }
    }
  }

  override def pyUDT: String = ""

  override def userClass: Class[Vector] = classOf[Vector]

  // All VectorUDT instances are interchangeable.
  override def equals(o: Any): Boolean = {
    o match {
      case _: VectorUDT => true
      case _ => false
    }
  }

  // see [SPARK-8647], this achieves the needed constant hash code without constant no.
  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()

  override def typeName: String = "vector"

  private[spark] override def asNullable: VectorUDT = this
}
Example 11
Source File: MatrixUDT.scala (from drizzle-spark, Apache License 2.0; 5 votes)

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

/**
 * Spark SQL user-defined type for [[Matrix]]: a matrix is encoded as a 7-field struct row.
 *
 * Field layout: 0 = type tag (0 = sparse, 1 = dense), 1 = numRows, 2 = numCols,
 * 3 = colPtrs (sparse only), 4 = rowIndices (sparse only), 5 = values, 6 = isTransposed.
 */
private[spark] class MatrixUDT extends UserDefinedType[Matrix] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
    // set as not nullable, except values since in the future, support for binary matrices might
    // be added for which values are not needed.
    // the sparse matrix needs colPtrs and rowIndices, which are set as
    // null, while building the dense matrix.
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("numRows", IntegerType, nullable = false),
      StructField("numCols", IntegerType, nullable = false),
      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
      StructField("isTransposed", BooleanType, nullable = false)
    ))
  }

  /** Encodes `obj` into the 7-field layout described by [[sqlType]]. */
  override def serialize(obj: Matrix): InternalRow = {
    val row = new GenericInternalRow(7)
    obj match {
      case sm: SparseMatrix =>
        row.setByte(0, 0)
        row.setInt(1, sm.numRows)
        row.setInt(2, sm.numCols)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
        row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
        row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
        row.setBoolean(6, sm.isTransposed)

      case dm: DenseMatrix =>
        row.setByte(0, 1)
        row.setInt(1, dm.numRows)
        row.setInt(2, dm.numCols)
        // Dense matrices have no colPtrs / rowIndices; store them as null explicitly.
        row.setNullAt(3)
        row.setNullAt(4)
        row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
        row.setBoolean(6, dm.isTransposed)
    }
    row
  }

  /**
   * Decodes a row produced by [[serialize]] back into a [[Matrix]].
   *
   * @throws IllegalArgumentException if the row does not have 7 fields or carries an unknown
   *                                  type tag
   */
  override def deserialize(datum: Any): Matrix = {
    datum match {
      case row: InternalRow =>
        require(row.numFields == 7,
          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
        val tpe = row.getByte(0)
        val numRows = row.getInt(1)
        val numCols = row.getInt(2)
        val values = row.getArray(5).toDoubleArray()
        val isTransposed = row.getBoolean(6)
        tpe match {
          case 0 =>
            val colPtrs = row.getArray(3).toIntArray()
            val rowIndices = row.getArray(4).toIntArray()
            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
          case 1 =>
            new DenseMatrix(numRows, numCols, values, isTransposed)
          case other =>
            throw new IllegalArgumentException(
              s"MatrixUDT.deserialize given unknown matrix type tag $other")
        }
    }
  }

  override def userClass: Class[Matrix] = classOf[Matrix]

  // All MatrixUDT instances are interchangeable.
  override def equals(o: Any): Boolean = {
    o match {
      case _: MatrixUDT => true
      case _ => false
    }
  }

  // see [SPARK-8647], this achieves the needed constant hash code without constant no.
  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()

  override def typeName: String = "matrix"

  override def pyUDT: String = ""

  private[spark] override def asNullable: MatrixUDT = this
}
Example 12
Source File: VectorUDT.scala (from ann4s, Apache License 2.0; 5 votes)

import ann4s.{Vector0, Vector16, Vector8, Vector32, Vector64, Vector}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

/**
 * Spark SQL user-defined type for ann4s vectors.
 *
 * Every vector is stored as a 5-field row. Field 0 is a type tag (0 = Vector0, 1 = Vector8,
 * 2 = Vector16, 3 = Vector32, 4 = Vector64). Fields 1-4 are primitive arrays of bytes,
 * shorts, floats and doubles respectively (see `_sqlType`). `serialize` writes only the
 * field(s) used by the tagged variant.
 */
class VectorUDT extends UserDefinedType[Vector] {

  // `sqlType` is a `def`, so `_sqlType` (declared at the bottom of the class) is read at call time.
  override def sqlType: DataType = _sqlType

  /** Writes `obj` into a 5-field row, recording its variant as the tag in field 0. */
  override def serialize(obj: Vector): InternalRow = {
    // NOTE(review): this extract ends without returning `row` and without the closing braces of
    // the `match` and the method.
    val row = new GenericInternalRow(5)
    obj match {
      case Vector0 =>
        row.setByte(0, 0)
      case Vector8(values, w, b) =>
        row.setByte(0, 1)
        row.update(1, UnsafeArrayData.fromPrimitiveArray(values))
        // Vector8 reuses the float32 slot (field 3) to store its (w, b) pair.
        row.update(3, UnsafeArrayData.fromPrimitiveArray(Array(w, b)))
      case Vector16(values) =>
        row.setByte(0, 2)
        row.update(2, UnsafeArrayData.fromPrimitiveArray(values))
      case Vector32(values) =>
        row.setByte(0, 3)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
      case Vector64(values) =>
        row.setByte(0, 4)
        row.update(4, UnsafeArrayData.fromPrimitiveArray(values))

  /**
   * Rebuilds a vector from the 5-field row written by `serialize`.
   *
   * @throws IllegalArgumentException (from `require`) if the row does not have 5 fields
   */
  override def deserialize(datum: Any): Vector = {
    // NOTE(review): in this extract the branches for tags 0, 2, 3 and 4 have no body, and the
    // closing braces are missing. By symmetry with `serialize` they presumably rebuild Vector0
    // and Vector16/32/64 from fields 2/3/4. Confirm against the original source.
    datum match {
      case row: InternalRow =>
        require(row.numFields == 5,
          s"nn.VectorUDT.deserialize given row with length ${row.numFields} but requires length == 5")
        val tpe = row.getByte(0)
        tpe match {
          case 0 =>
          case 1 =>
            // Field 3 holds the (w, b) pair that `serialize` stores for Vector8.
            val wb = row.getArray(3).toFloatArray()
            Vector8(row.getArray(1).toByteArray(), wb(0), wb(1))
          case 2 =>
          case 3 =>
          case 4 =>

  /** The user-facing class handled by this UDT. */
  override def userClass: Class[Vector] = classOf[Vector]

  /** Any `VectorUDT` instance is equal to any other `VectorUDT`; every other object is not. */
  override def equals(o: Any): Boolean = {
    // NOTE(review): the closing braces of this `match` and method are missing from this extract.
    o match {
      case _: VectorUDT => true
      case _ => false

  // Constant per class (derived from the class name), consistent with `equals`.
  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode

  /** Type name reported for this UDT. */
  override def typeName: String = "nn.vector"

  // Returns this same instance unchanged.
  private[spark] override def asNullable: VectorUDT = this

  // Row schema: a non-nullable type tag plus one nullable primitive-array field per encoding.
  // NOTE(review): the `StructType(Seq(` opener of this field list is missing from this extract.
  private[this] val _sqlType = {
      StructField("type", ByteType, nullable = false),
      StructField("fixed8", ArrayType(ByteType, containsNull = false), nullable = true),
      StructField("fixed16", ArrayType(ShortType, containsNull = false), nullable = true),
      StructField("float32", ArrayType(FloatType, containsNull = false), nullable = true),
      StructField("float64", ArrayType(DoubleType, containsNull = false), nullable = true)))

/** Companion object holding the registration helper for the ann4s vector classes. */
object VectorUDT {

  /**
   * Calls `UDTRegistration.register` once for each ann4s vector class name listed below.
   *
   * NOTE(review): every call in this extract passes an empty string as the second (UDT class
   * name) argument. That looks like an extraction loss; confirm against the original source.
   * The closing braces are also missing here.
   */
  def register(): Unit = {
    UDTRegistration.register("ann4s.Vector", "")
    UDTRegistration.register("ann4s.EmptyVector", "")
    UDTRegistration.register("ann4s.Fixed8Vector", "")
    UDTRegistration.register("ann4s.Fixed16Vector", "")
    UDTRegistration.register("ann4s.Float32Vector", "")
    UDTRegistration.register("ann4s.Float64Vector", "")

Example 13
Source File: CatalystTypeConvertersSuite.scala (from the sparkoscope project, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

/**
 * Tests for `CatalystTypeConverters`: null handling for rows and single values, `Option`
 * unwrapping, and conversion of primitive and nullable arrays.
 *
 * NOTE(review): this extract is lossy. Closing braces, the elements of `simpleTypes`, and the
 * left-hand side of several assertions (lines starting with `===` / `==`) are missing.
 */
class CatalystTypeConvertersSuite extends SparkFunSuite {

  // Data types exercised by the null-handling tests (element list missing from this extract).
  private val simpleTypes: Seq[DataType] = Seq(

  test("null handling in rows") {
    // NOTE(review): the start of the mapping lambda over the types is missing on the next line.
    val schema = StructType( => StructField(t.getClass.getName, t)))
    val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
    val convertToScala = CatalystTypeConverters.createToScalaConverter(schema)

    // A row of all nulls must survive a Scala -> Catalyst -> Scala round trip unchanged.
    val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null))
    assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow)

  test("null handling for individual values") {
    for (dataType <- simpleTypes) {
      assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null)

  test("option handling in convertToCatalyst") {
    // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with
    // createToCatalystConverter but it may not actually matter as this is only called internally
    // in a handful of places where we don't expect to receive Options.
    assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123))

  test("option handling in createToCatalystConverter") {
    assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)

  test("primitive array handling") {
    val intArray = Array(1, 100, 10000)
    val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray)
    val intArrayType = ArrayType(IntegerType, false)
    assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) === intArray)

    val doubleArray = Array(1.1, 111.1, 11111.1)
    val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray)
    val doubleArrayType = ArrayType(DoubleType, false)
    // NOTE(review): the assertion's left-hand side is missing from this extract.
      === doubleArray)

  test("An array with null handling") {
    val intArray = Array(1, null, 100, null, 10000)
    val intGenericArray = new GenericArrayData(intArray)
    val intArrayType = ArrayType(IntegerType, true)
    // NOTE(review): the left-hand sides of the next two assertions are missing from this extract.
      === intArray)
      == intGenericArray)

    val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
    val doubleGenericArray = new GenericArrayData(doubleArray)
    val doubleArrayType = ArrayType(DoubleType, true)
    // NOTE(review): the left-hand sides of the next two assertions are missing from this extract.
      === doubleArray)
      == doubleGenericArray)
Example 14
Source File: VectorUDT.scala (from the sparkoscope project, Apache License 2.0; 5 votes)

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

/**
 * Spark SQL user-defined type for `Vector`, stored as a 4-field row.
 *
 * Field 0 is the type tag (0 = sparse, 1 = dense), field 1 the size, field 2 the indices and
 * field 3 the values. Dense vectors write only fields 0 and 3.
 */
private[spark] class VectorUDT extends UserDefinedType[Vector] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
    // vectors. The "values" field is nullable because we might want to add binary vectors later,
    // which use "size" and "indices", but not "values".
    // NOTE(review): the `StructType(Seq(` opener and closing brace are missing from this extract.
      StructField("type", ByteType, nullable = false),
      StructField("size", IntegerType, nullable = true),
      StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))

  /** Writes a sparse or dense vector into a fresh 4-field row. */
  override def serialize(obj: Vector): InternalRow = {
    // NOTE(review): neither case returns `row` in this extract, and the closing braces are
    // missing.
    obj match {
      case SparseVector(size, indices, values) =>
        val row = new GenericInternalRow(4)
        row.setByte(0, 0)
        row.setInt(1, size)
        row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
      case DenseVector(values) =>
        val row = new GenericInternalRow(4)
        row.setByte(0, 1)
        // Dense vectors leave size (field 1) and indices (field 2) unset.
        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))

  /**
   * Rebuilds a `SparseVector` (tag 0) or `DenseVector` (tag 1) from the 4-field row.
   *
   * @throws IllegalArgumentException (from `require`) if the row does not have 4 fields
   */
  override def deserialize(datum: Any): Vector = {
    // NOTE(review): closing braces are missing from this extract; only tags 0 and 1 are visible.
    datum match {
      case row: InternalRow =>
        require(row.numFields == 4,
          s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
        val tpe = row.getByte(0)
        tpe match {
          case 0 =>
            val size = row.getInt(1)
            val indices = row.getArray(2).toIntArray()
            val values = row.getArray(3).toDoubleArray()
            new SparseVector(size, indices, values)
          case 1 =>
            val values = row.getArray(3).toDoubleArray()
            new DenseVector(values)

  // NOTE(review): empty here; confirm whether a Python UDT class name was lost in extraction.
  override def pyUDT: String = ""

  /** The user-facing class handled by this UDT. */
  override def userClass: Class[Vector] = classOf[Vector]

  /** Any `VectorUDT` instance is equal to any other `VectorUDT`; every other object is not. */
  override def equals(o: Any): Boolean = {
    o match {
      case v: VectorUDT => true
      case _ => false

  // See [SPARK-8647]: derive a constant hash code from the class name instead of a hard-coded
  // number. It is identical for every instance, which keeps it consistent with `equals`.
  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()

  /** Type name reported for this UDT. */
  override def typeName: String = "vector"

  // Returns this same instance unchanged.
  private[spark] override def asNullable: VectorUDT = this
Example 15
Source File: MatrixUDT.scala (from the sparkoscope project, Apache License 2.0; 5 votes)

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

/**
 * Spark SQL user-defined type for `Matrix`, stored as a 7-field row.
 *
 * The fields are: type tag (0 = sparse, 1 = dense), numRows, numCols, colPtrs, rowIndices,
 * values and isTransposed. Dense matrices leave colPtrs and rowIndices unset.
 */
private[spark] class MatrixUDT extends UserDefinedType[Matrix] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // The dense matrix is built from numRows, numCols, values and isTransposed. All of them are
    // non-nullable except values, because support for binary matrices (which need no values)
    // might be added in the future.
    // The sparse matrix also needs colPtrs and rowIndices, which are left null when building a
    // dense matrix.
    // NOTE(review): the `StructType(Seq(` opener and its closing `))` are missing from this
    // extract.
      StructField("type", ByteType, nullable = false),
      StructField("numRows", IntegerType, nullable = false),
      StructField("numCols", IntegerType, nullable = false),
      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
      StructField("isTransposed", BooleanType, nullable = false)

  /** Writes a sparse or dense matrix into a 7-field row. */
  override def serialize(obj: Matrix): InternalRow = {
    // NOTE(review): this extract ends without returning `row` and without the closing braces.
    val row = new GenericInternalRow(7)
    obj match {
      case sm: SparseMatrix =>
        row.setByte(0, 0)
        row.setInt(1, sm.numRows)
        row.setInt(2, sm.numCols)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
        row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
        row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
        row.setBoolean(6, sm.isTransposed)

      case dm: DenseMatrix =>
        // Dense layout leaves colPtrs (field 3) and rowIndices (field 4) unset.
        row.setByte(0, 1)
        row.setInt(1, dm.numRows)
        row.setInt(2, dm.numCols)
        row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
        row.setBoolean(6, dm.isTransposed)

  /**
   * Rebuilds a `SparseMatrix` (tag 0) or `DenseMatrix` (tag 1) from the 7-field row.
   *
   * @throws IllegalArgumentException (from `require`) if the row does not have 7 fields
   */
  override def deserialize(datum: Any): Matrix = {
    // NOTE(review): closing braces are missing from this extract; only tags 0 and 1 are visible.
    datum match {
      case row: InternalRow =>
        require(row.numFields == 7,
          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
        val tpe = row.getByte(0)
        val numRows = row.getInt(1)
        val numCols = row.getInt(2)
        // Values and the transpose flag are used by both layouts, so they are read up front.
        val values = row.getArray(5).toDoubleArray()
        val isTransposed = row.getBoolean(6)
        tpe match {
          case 0 =>
            val colPtrs = row.getArray(3).toIntArray()
            val rowIndices = row.getArray(4).toIntArray()
            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
          case 1 =>
            new DenseMatrix(numRows, numCols, values, isTransposed)

  /** The user-facing class handled by this UDT. */
  override def userClass: Class[Matrix] = classOf[Matrix]

  /** Any `MatrixUDT` instance is equal to any other `MatrixUDT`; every other object is not. */
  override def equals(o: Any): Boolean = {
    o match {
      case v: MatrixUDT => true
      case _ => false

  // See [SPARK-8647]: derive a constant hash code from the class name instead of a hard-coded
  // number. It is identical for every instance, which keeps it consistent with `equals`.
  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()

  /** Type name reported for this UDT. */
  override def typeName: String = "matrix"

  // NOTE(review): empty here; confirm whether a Python UDT class name was lost in extraction.
  override def pyUDT: String = ""

  // Returns this same instance unchanged.
  private[spark] override def asNullable: MatrixUDT = this
Example 16
Source File: ArrayDataIndexedSeqSuite.scala (from the XSQL project, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.util

import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection}
import org.apache.spark.sql.types._

/**
 * Checks that `ArrayData` element access and `toSeq` agree with a plain array copy, for both
 * `UnsafeArrayData` and `GenericArrayData`, across many element types and both nullabilities.
 *
 * NOTE(review): this extract is lossy. Closing braces, the heads of the per-element loops, the
 * bodies of the `intercept` blocks, and the call to `testArrayData()` are not visible.
 */
class ArrayDataIndexedSeqSuite extends SparkFunSuite {
  /** Asserts that `arrayData` and its `toSeq` view match `array` element by element. */
  private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = {
    // NOTE(review): the per-element loop head that introduces `{ case (e, i) =>` appears to have
    // been merged into the next line during extraction.
    assert(arrayData.numElements == array.length) { case (e, i) =>
      if (e != null) {
        elementDt match {
          // Use `equals` rather than `===` so that NaN compares equal to NaN.
          case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e))
          case _ => assert(arrayData.get(i, elementDt) === e)
      } else {

    // NOTE(review): as above, the loop head for `{ case (e, i) =>` is merged into this line.
    val seq = arrayData.toSeq[Any](elementDt) { case (e, i) =>
      if (e != null) {
        elementDt match {
          // Use `equals` rather than `===` so that NaN compares equal to NaN.
          case FloatType | DoubleType => assert(seq(i).equals(e))
          case _ => assert(seq(i) === e)
      } else {
        assert(seq(i) == null)

    // NOTE(review): the Boolean from `.contains(...)` below is discarded, so the exception
    // message is never actually checked; only the exception type is.
    intercept[IndexOutOfBoundsException] {
    }.getMessage().contains("must be between 0 and the length of the ArrayData.")

    intercept[IndexOutOfBoundsException] {
    }.getMessage().contains("must be between 0 and the length of the ArrayData.")

  /** Registers one unsafe and one generic test per (element type, containsNull) combination. */
  private def testArrayData(): Unit = {
    val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
      DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType,
      CalendarIntervalType, new ExamplePointUDT())
    val arrayTypes = elementTypes.flatMap { elementType =>
      Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true))
    // Fixed seed so the generated rows are reproducible across runs.
    val random = new Random(100)
    arrayTypes.foreach { dt =>
      val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil)
      val row = RandomDataGenerator.randomRow(random, schema)
      val rowConverter = RowEncoder(schema)
      val internalRow = rowConverter.toRow(row)

      val unsafeRowConverter = UnsafeProjection.create(schema)
      val safeRowConverter = FromUnsafeProjection(schema)

      // Build both physical representations of the same random array.
      val unsafeRow = unsafeRowConverter(internalRow)
      val safeRow = safeRowConverter(unsafeRow)

      val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData]
      val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]

      val elementType = dt.elementType
      test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) {
        compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType))

      test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) {
        compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType))

Example 17
Source File: ArrayData.scala (from the XSQL project, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.util

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods

object ArrayData {
  /**
   * Wraps `input` as [[ArrayData]].
   *
   * Arrays of Boolean, Byte, Short, Int, Long, Float and Double go through
   * `UnsafeArrayData.fromPrimitiveArray`. Any other value is handed to the
   * `GenericArrayData` constructor. JVM arrays keep their element type at runtime, so these
   * `Array[...]` type patterns are genuinely checked rather than erased.
   *
   * NOTE(review): the closing braces are missing from this extract.
   */
  def toArrayData(input: Any): ArrayData = input match {
    case a: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(a)
    case a: Array[Byte] => UnsafeArrayData.fromPrimitiveArray(a)
    case a: Array[Short] => UnsafeArrayData.fromPrimitiveArray(a)
    case a: Array[Int] => UnsafeArrayData.fromPrimitiveArray(a)
    case a: Array[Long] => UnsafeArrayData.fromPrimitiveArray(a)
    case a: Array[Float] => UnsafeArrayData.fromPrimitiveArray(a)
    case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a)
    case other => new GenericArrayData(other)

/**
 * Read-only `IndexedSeq` view over an [[ArrayData]]. Elements are fetched from `arrayData` on
 * each access, using an accessor for `dataType`, rather than being copied up front.
 */
class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] {

  // Accessor resolved once for the element type and reused for every `apply` call.
  private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType)

  /**
   * Returns the element at `idx`.
   *
   * @throws IndexOutOfBoundsException if `idx` is outside `[0, length)`
   */
  override def apply(idx: Int): T =
    if (0 <= idx && idx < arrayData.numElements()) {
      // NOTE(review): the body of the null branch is missing from this extract, as are the
      // closing braces.
      if (arrayData.isNullAt(idx)) {
      } else {
        accessor(arrayData, idx).asInstanceOf[T]
    } else {
      throw new IndexOutOfBoundsException(
        s"Index $idx must be between 0 and the length of the ArrayData.")

  /** Number of elements in the underlying `ArrayData`. */
  override def length: Int = arrayData.numElements()
Example 18
Source File: MLMatrixSerializer.scala (from the MatRel project, Apache License 2.0; 5 votes)
package org.apache.spark.sql.matfast.util

import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.matfast.matrix._

/**
 * Converts `MLMatrix` values to and from a 7-field `InternalRow`. The fields are: type tag
 * (0 = sparse, 1 = dense), numRows, numCols, colPtrs, rowIndices, values and isTransposed.
 */
object MLMatrixSerializer {

  /** Writes a sparse or dense matrix into a 7-field row. */
  def serialize(obj: MLMatrix): InternalRow = {
    // NOTE(review): this extract ends without returning `row` and without the closing braces.
    val row = new GenericInternalRow(7)
    obj match {
      case sm: SparseMatrix =>
        row.setByte(0, 0)
        row.setInt(1, sm.numRows)
        row.setInt(2, sm.numCols)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
        row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
        row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
        row.setBoolean(6, sm.isTransposed)

      case dm: DenseMatrix =>
        // Dense layout leaves colPtrs (field 3) and rowIndices (field 4) unset.
        row.setByte(0, 1)
        row.setInt(1, dm.numRows)
        row.setInt(2, dm.numCols)
        row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
        row.setBoolean(6, dm.isTransposed)

  /**
   * Rebuilds a `SparseMatrix` (tag 0) or `DenseMatrix` (tag 1) from the 7-field row.
   *
   * @throws IllegalArgumentException (from `require`) if the row does not have 7 fields
   */
  def deserialize(datum: Any): MLMatrix = {
    // NOTE(review): the `require` message names `MatrixUDT.deserialize`, not this object, which
    // may mislead when debugging. Closing braces are also missing from this extract.
    datum match {
      case row: InternalRow =>
        require(row.numFields == 7,
          s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7")
        val tpe = row.getByte(0)
        val numRows = row.getInt(1)
        val numCols = row.getInt(2)
        // Values and the transpose flag are used by both layouts, so they are read up front.
        val values = row.getArray(5).toDoubleArray()
        val isTransposed = row.getBoolean(6)
        tpe match {
          case 0 =>
            val colPtrs = row.getArray(3).toIntArray()
            val rowIndices = row.getArray(4).toIntArray()
            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
          case 1 =>
            new DenseMatrix(numRows, numCols, values, isTransposed)

// Class sharing its name with the `MLMatrixSerializer` object above; no members are visible in
// this extract, and its closing brace is missing.
class MLMatrixSerializer {

Example 19
Source File: VectorUDT.scala (from the drizzle-spark project, Apache License 2.0; 5 votes)

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._

private[spark] class VectorUDT extends UserDefinedType[Vector] {

  /**
   * Row schema for vectors: a non-nullable type tag (0 = sparse, 1 = dense) followed by
   * nullable size, indices and values fields.
   */
  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
    // vectors. The "values" field is nullable because we might want to add binary vectors later,
    // which use "size" and "indices", but not "values".
    // NOTE(review): the `StructType(Seq(` opener and closing brace are missing from this extract.
      StructField("type", ByteType, nullable = false),
      StructField("size", IntegerType, nullable = true),
      StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))

  /**
   * Writes a vector into a fresh 4-field row. Sparse vectors fill the tag, size, indices and
   * values. Dense vectors fill only the tag and values; fields 1 and 2 are left unset.
   */
  override def serialize(obj: Vector): InternalRow = {
    // NOTE(review): neither case returns `row` in this extract, and the closing braces are
    // missing.
    obj match {
      case SparseVector(size, indices, values) =>
        val row = new GenericInternalRow(4)
        row.setByte(0, 0)
        row.setInt(1, size)
        row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
      case DenseVector(values) =>
        val row = new GenericInternalRow(4)
        row.setByte(0, 1)
        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))

  /**
   * Rebuilds a `SparseVector` (tag 0) or `DenseVector` (tag 1) from the 4-field row written
   * by `serialize`.
   *
   * @throws IllegalArgumentException (from `require`) if the row does not have 4 fields
   */
  override def deserialize(datum: Any): Vector = {
    // NOTE(review): closing braces are missing from this extract; only tags 0 and 1 are visible.
    datum match {
      case row: InternalRow =>
        require(row.numFields == 4,
          s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4")
        val tpe = row.getByte(0)
        tpe match {
          case 0 =>
            val size = row.getInt(1)
            val indices = row.getArray(2).toIntArray()
            val values = row.getArray(3).toDoubleArray()
            new SparseVector(size, indices, values)
          case 1 =>
            // Dense vectors store only their values (field 3).
            val values = row.getArray(3).toDoubleArray()
            new DenseVector(values)

  // NOTE(review): empty here; confirm whether a Python UDT class name was lost in extraction.
  override def pyUDT: String = ""

  /** The user-facing class handled by this UDT. */
  override def userClass: Class[Vector] = classOf[Vector]

  /** Any `VectorUDT` instance is equal to any other `VectorUDT`; every other object is not. */
  override def equals(o: Any): Boolean = {
    // NOTE(review): the closing braces of this `match` and method are missing from this extract.
    o match {
      case v: VectorUDT => true
      case _ => false

  // See [SPARK-8647]: derive a constant hash code from the class name instead of a hard-coded
  // number. It is identical for every instance, which keeps it consistent with `equals`.
  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()

  /** Type name reported for this UDT. */
  override def typeName: String = "vector"

  // Returns this same instance unchanged.
  private[spark] override def asNullable: VectorUDT = this