package com.esri.gdb

import org.apache.hadoop.conf.Configuration
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}

/**
  * An RDD over the rows of a single table in an Esri File Geodatabase (.gdb).
  * The table's rows are split into `numPartitions` contiguous ranges, each read
  * independently by one task.
  */
case class GDBRDD(@transient sc: SparkContext, gdbPath: String, gdbName: String, numPartitions: Int) extends RDD[Row](sc, Nil) with Logging {

  @DeveloperApi
  override def compute(partition: Partition, context: TaskContext): Iterator[Row] = {
    val part = partition.asInstanceOf[GDBPartition]
    // sc is @transient, so it deserializes to null on the executors; fall back
    // to a default Hadoop configuration there.
    val hadoopConf = if (sc == null) new Configuration() else sc.hadoopConfiguration
    val index = GDBIndex(gdbPath, part.hexName, hadoopConf)
    val table = GDBTable(gdbPath, part.hexName, hadoopConf)
    // Release both file handles when the task completes, whether it succeeds or fails.
    context.addTaskCompletionListener(_ => {
      table.close()
      index.close()
    })
    table.rowIterator(index, part.startAtRow, part.numRowsToRead)
  }

  override protected def getPartitions: Array[Partition] = {
    val hadoopConf = if (sc == null) new Configuration() else sc.hadoopConfiguration
    GDBTable.findTable(gdbPath, gdbName, hadoopConf) match {
      case Some(catTab) =>
        val index = GDBIndex(gdbPath, catTab.hexName, hadoopConf)
        try {
          // Split the row range into numPartitions contiguous chunks, trimming
          // the last chunk so no partition reads past the end of the table.
          val numRows = index.numRows
          val numRowsPerPartition = (numRows.toDouble / numPartitions).ceil.toInt
          var startAtRow = 0
          (0 until numPartitions).map(i => {
            val endAtRow = startAtRow + numRowsPerPartition
            val numRowsToRead = if (endAtRow <= numRows) numRowsPerPartition else numRows - startAtRow
            val gdbPartition = GDBPartition(i, catTab.hexName, startAtRow, numRowsToRead)
            startAtRow += numRowsToRead
            gdbPartition
          }).toArray
        } finally {
          index.close()
        }
      case _ =>
        log.error(s"Cannot find '$gdbName' in $gdbPath; returning an empty array of partitions.")
        Array.empty[Partition]
    }
  }
}

/**
  * A contiguous range of rows in a geodatabase table, identified by the table's
  * internal hex file name.
  */
private[gdb] case class GDBPartition(m_index: Int,
                                     hexName: String,
                                     startAtRow: Int,
                                     numRowsToRead: Int
                                    ) extends Partition {
  override def index: Int = m_index
}
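
// --- Usage sketch (illustrative, not part of the original file) ---
// A minimal driver showing how GDBRDD might be consumed directly. The geodatabase
// path "data/Test.gdb" and table name "Cities" are placeholder values; only
// GDBRDD itself comes from the code above.
object GDBRDDExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[*]", "GDBRDDExample")
    try {
      val rdd = GDBRDD(sc, "data/Test.gdb", "Cities", numPartitions = 4)
      // Print the first few rows; count() would similarly exercise all partitions.
      rdd.take(5).foreach(println)
    } finally {
      sc.stop()
    }
  }
}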