org.apache.spark.sql.internal.SQLConf Scala Examples

The following examples show how to use org.apache.spark.sql.internal.SQLConf. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: SparkPlanner.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.internal.SQLConf

/**
 * Planner that combines caller-supplied strategies with a fixed list of built-in ones and
 * offers a helper for building pruned, filtered scans.
 *
 * @param sparkContext the Spark context made available as a public member
 * @param conf SQL configuration; supplies the shuffle partition count
 * @param extraStrategies additional strategies, listed ahead of the built-in ones
 */
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val extraStrategies: Seq[Strategy])
  extends SparkStrategies {

  /** Number of shuffle partitions, read from `conf` on every call. */
  def numPartitions: Int = conf.numShufflePartitions

  /** All strategies: the caller-supplied ones first, followed by the built-in list. */
  def strategies: Seq[Strategy] = {
    val builtIn = List[Strategy](
      FileSourceStrategy,
      DataSourceStrategy,
      DDLStrategy,
      SpecialLimits,
      Aggregation,
      JoinSelection,
      InMemoryScans,
      BasicOperators)
    extraStrategies ++ builtIn
  }

  /** Pairs every `PlanLater` node found in `plan` with the logical plan it wraps. */
  override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] =
    plan.collect { case later @ PlanLater(logical) => (later, logical) }

  /** Returns the candidate plans unchanged; no pruning is performed yet. */
  override protected def prunePlans(plans: Iterator[SparkPlan]): Iterator[SparkPlan] = {
    // TODO: We will need to prune bad plans when we improve plan space exploration
    //       to prevent combinatorial explosion.
    plans
  }

  /**
   * Builds a scan through `scanBuilder`, wraps it in a filter when predicates remain after
   * `prunePushedDownFilters`, and adds a projection only when column pruning alone cannot
   * produce `projectList`.
   *
   * @param projectList expressions the resulting plan must output
   * @param filterPredicates all predicates that apply to the scan
   * @param prunePushedDownFilters selects the predicates that still have to be evaluated
   *                               on top of the scan
   * @param scanBuilder creates the scan for a given list of required attributes
   * @return the scan, optionally wrapped in a filter and/or a projection
   */
  def pruneFilterProject(
      projectList: Seq[NamedExpression],
      filterPredicates: Seq[Expression],
      prunePushedDownFilters: Seq[Expression] => Seq[Expression],
      scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {

    val projectSet = AttributeSet(projectList.flatMap(_.references))
    val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
    // Conjunction of the predicates that are left to evaluate, if any.
    val filterCondition: Option[Expression] =
      prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And)

    // Wraps `scan` in a filter only when there is a condition left to evaluate.
    def withRemainingFilter(scan: SparkPlan): SparkPlan = filterCondition match {
      case Some(predicate) => FilterExec(predicate, scan)
      case None => scan
    }

    // Right now we still use a projection even if the only evaluation is applying an alias
    // to a column.  Since this is a no-op, it could be avoided. However, using this
    // optimization with the current implementation would change the output schema.
    // TODO: Decouple final output schema from expression evaluation so this copy can be
    // avoided safely.

    val outputAttributes = AttributeSet(projectList.map(_.toAttribute))
    if (outputAttributes == projectSet && filterSet.subsetOf(projectSet)) {
      // Column pruning alone yields the requested projection and the projected columns are
      // enough to evaluate every filter: scan, then filter, with no extra project.
      withRemainingFilter(scanBuilder(projectList.asInstanceOf[Seq[Attribute]]))
    } else {
      // Otherwise read everything the projection and the filters reference, then project.
      val requiredColumns = (projectSet ++ filterSet).toSeq
      ProjectExec(projectList, withRemainingFilter(scanBuilder(requiredColumns)))
    }
  }
}
Example 2
Source File: CartesianProductExec.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter


/**
 * Cartesian product of two RDDs of [[UnsafeRow]]s. For each pair of partitions the rows of
 * the right side are first copied into an [[UnsafeExternalSorter]], which is used only as a
 * buffer, and are then replayed once for every row of the left side.
 */
class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
  extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {

  override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
    val env = SparkEnv.get
    // The rows are never sorted, hence the null prefixComparator and recordComparator.
    val buffer = UnsafeExternalSorter.create(
      context.taskMemoryManager(),
      env.blockManager,
      env.serializerManager,
      context,
      null,
      null,
      1024,
      env.memoryManager.pageSizeBytes,
      env.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
        UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
      false)

    val cartesianSplit = split.asInstanceOf[CartesianPartition]
    // Copy the whole right-side partition into the buffer before producing any output.
    rdd2.iterator(cartesianSplit.s2, context).foreach { row =>
      buffer.insertRecord(row.getBaseObject, row.getBaseOffset, row.getSizeInBytes, 0, false)
    }

    // Exposes the buffered records as an Iterator[UnsafeRow]. A single UnsafeRow instance is
    // re-pointed at each record, so every call to next() returns the same object.
    def bufferedRightRows(): Iterator[UnsafeRow] = {
      val records = buffer.getIterator
      val reusedRow = new UnsafeRow(numFieldsOfRight)
      new Iterator[UnsafeRow] {
        override def hasNext: Boolean = records.hasNext

        override def next(): UnsafeRow = {
          records.loadNext()
          reusedRow.pointTo(records.getBaseObject, records.getBaseOffset, records.getRecordLength)
          reusedRow
        }
      }
    }

    // Pair every left row with a fresh pass over the buffered right rows.
    val pairs = rdd1.iterator(cartesianSplit.s1, context).flatMap { leftRow =>
      bufferedRightRows().map(rightRow => (leftRow, rightRow))
    }
    // Release the buffer once the paired iterator has been fully consumed.
    CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
      pairs, buffer.cleanupResources())
  }
}


/**
 * Physical operator that produces the cartesian product of `left` and `right`, keeping only
 * the row pairs that satisfy `condition` when one is given.
 */
case class CartesianProductExec(
    left: SparkPlan,
    right: SparkPlan,
    condition: Option[Expression]) extends BinaryExecNode {
  override def output: Seq[Attribute] = left.output ++ right.output

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")

    val leftRows = left.execute().asInstanceOf[RDD[UnsafeRow]]
    val rightRows = right.execute().asInstanceOf[RDD[UnsafeRow]]
    val pairs = new UnsafeCartesianRDD(leftRows, rightRows, right.output.size)

    pairs.mapPartitionsInternal { iter =>
      val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
      // Drop the pairs rejected by the join condition, if there is one.
      val surviving = condition match {
        case Some(predicate) =>
          val boundCondition: (InternalRow) => Boolean =
            newPredicate(predicate, left.output ++ right.output)
          val joined = new JoinedRow
          iter.filter { case (l, r) => boundCondition(joined(l, r)) }
        case None =>
          iter
      }
      // Count each emitted row and merge the two sides into a single row.
      surviving.map { case (l, r) =>
        numOutputRows += 1
        joiner.join(l, r)
      }
    }
  }
}
Example 3
Source File: ParquetOptions.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.hadoop.metadata.CompressionCodecName

import org.apache.spark.sql.internal.SQLConf


  // An explicit MERGE_SCHEMA entry in `parameters` wins; without one the value comes from
  // `sqlConf.isParquetSchemaMergingEnabled`.
  val mergeSchema: Boolean =
    parameters.get(MERGE_SCHEMA).fold(sqlConf.isParquetSchemaMergingEnabled)(_.toBoolean)
}


object ParquetOptions {
  // Key of the option that is looked up when deciding whether to merge schemas.
  val MERGE_SCHEMA: String = "mergeSchema"

  // The parquet compression short names, each mapped to its codec constant.
  // Both "none" and "uncompressed" map to UNCOMPRESSED.
  private val shortParquetCompressionCodecNames: Map[String, CompressionCodecName] = Map(
    ("none", CompressionCodecName.UNCOMPRESSED),
    ("uncompressed", CompressionCodecName.UNCOMPRESSED),
    ("snappy", CompressionCodecName.SNAPPY),
    ("gzip", CompressionCodecName.GZIP),
    ("lzo", CompressionCodecName.LZO))
}
Example 4
Source File: Exchange.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType


/**
 * Rule that replaces an [[Exchange]] with a [[ReusedExchangeExec]] whenever a previously
 * visited exchange of the same plan reports the same result. The plan is returned untouched
 * unless `conf.exchangeReuseEnabled` is set.
 */
case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (conf.exchangeReuseEnabled) {
      // Group the exchanges seen so far by schema to avoid O(N*N) sameResult calls.
      val seenBySchema = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
      plan.transformUp {
        case exchange: Exchange =>
          // Exchanges with the same results usually also share a schema (same column names).
          val candidates = seenBySchema.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
          candidates.find(seen => exchange.sameResult(seen)) match {
            case Some(reusable) =>
              // Keep the output of this exchange: the plans above need it to resolve
              // their attributes.
              ReusedExchangeExec(exchange.output, reusable)
            case None =>
              candidates += exchange
              exchange
          }
      }
    } else {
      plan
    }
  }
}
Example 5
Source File: subquery.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}


/**
 * Rule that makes subquery expressions share a plan: when a previously visited subquery plan
 * reports the same result, the expression is rewritten to point at that plan. The plan is
 * returned untouched unless `conf.exchangeReuseEnabled` is set.
 */
case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (conf.exchangeReuseEnabled) {
      // Group the subquery plans seen so far by schema to avoid O(N*N) sameResult calls.
      val seenBySchema = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
      plan transformAllExpressions {
        case sub: ExecSubqueryExpression =>
          val candidates =
            seenBySchema.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
          candidates.find(_.sameResult(sub.plan)) match {
            case Some(reusable) =>
              sub.withNewPlan(reusable)
            case None =>
              // First plan with this result: remember it for later matches.
              candidates += sub.plan
              sub
          }
      }
    } else {
      plan
    }
  }
}
Example 6
Source File: FileStreamSinkLog.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import org.apache.hadoop.fs.{FileStatus, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization
import org.json4s.jackson.Serialization.{read, write}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf


/**
 * A [[CompactibleFileStreamLog]] of [[SinkFileStatus]] entries that are stored as JSON.
 * Cleanup delay, expired-log deletion and compaction interval are all read from the SQL
 * configuration of `sparkSession`.
 */
class FileStreamSinkLog(
    metadataLogVersion: String,
    sparkSession: SparkSession,
    path: String)
  extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) {

  // json4s formats used implicitly by the (de)serialization methods below.
  private implicit val formats = Serialization.formats(NoTypeHints)

  protected override val fileCleanupDelayMs = sparkSession.sessionState.conf.fileSinkLogCleanupDelay

  protected override val isDeletingExpiredLog = sparkSession.sessionState.conf.fileSinkLogDeletion

  protected override val compactInterval = sparkSession.sessionState.conf.fileSinkLogCompactInterval
  // Reject a non-positive compaction interval at construction time.
  require(compactInterval > 0,
    s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $compactInterval) " +
      "to a positive value.")

  /** Encodes one log entry as a JSON string. */
  protected override def serializeData(data: SinkFileStatus): String =
    Serialization.write(data)

  /** Decodes one log entry from its JSON string. */
  protected override def deserializeData(encodedString: String): SinkFileStatus =
    Serialization.read[SinkFileStatus](encodedString)

  /**
   * Removes every entry whose path also appears in an entry carrying the delete action.
   * When no such entry exists, `logs` is returned as is.
   */
  override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = {
    val deletedPaths = logs.collect {
      case entry if entry.action == FileStreamSinkLog.DELETE_ACTION => entry.path
    }.toSet
    if (deletedPaths.nonEmpty) {
      logs.filterNot(entry => deletedPaths.contains(entry.path))
    } else {
      logs
    }
  }
}

/** String constants shared by users of the file stream sink log. */
object FileStreamSinkLog {
  // Version identifier; presumably passed as the log's metadataLogVersion (confirm at callers).
  val VERSION: String = "v1"
  // Action value marking an entry as deleted; `compactLogs` keys off this value.
  val DELETE_ACTION: String = "delete"
  // Action value for an added file (assumed counterpart of DELETE_ACTION; not used in view).
  val ADD_ACTION: String = "add"
}
Example 7
Source File: SparkOptimizer.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.internal.SQLConf

/**
 * Optimizer that runs the batches inherited from [[Optimizer]] followed by four additional
 * batches, the last of which holds the rules taken from `experimentalMethods`.
 */
class SparkOptimizer(
    catalog: SessionCatalog,
    conf: SQLConf,
    experimentalMethods: ExperimentalMethods)
  extends Optimizer(catalog, conf) {

  override def batches: Seq[Batch] = {
    val inherited = super.batches
    // Appended after the inherited batches, in this order.
    val appended = Seq(
      Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)),
      Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate),
      Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions),
      Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*))
    inherited ++ appended
  }
}
Example 8
Source File: TestSparkSession.scala    From spark-alchemy   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.test

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.internal.{SQLConf, SessionState, SessionStateBuilder, WithTestConf}
import org.apache.spark.{SparkConf, SparkContext}


  // Configuration overrides keyed by SQL conf name.
  val overrideConfs: Map[String, String] = {
    // Fewer shuffle partitions to speed up testing.
    val shufflePartitions = (SQLConf.SHUFFLE_PARTITIONS.key, "3")
    Map(shufflePartitions)
  }
}

/**
 * Session state builder that mixes in [[WithTestConf]] and supplies the overrides defined by
 * `TestSQLContext.overrideConfs`.
 */
private[sql] class TestSQLSessionStateBuilder(
  session: SparkSession,
  state: Option[SessionState])
  extends SessionStateBuilder(session, state) with WithTestConf {
  override def overrideConfs: Map[String, String] = TestSQLContext.overrideConfs

  // Builders created through this factory are instances of this class as well.
  override def newBuilder: NewBuilder =
    (sess, maybeState) => new TestSQLSessionStateBuilder(sess, maybeState)
}
Example 9
Source File: QueryPartitionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import java.io.File
import java.sql.Timestamp

import com.google.common.io.Files
import org.apache.hadoop.fs.FileSystem

import org.apache.spark.internal.config._
import org.apache.spark.sql._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils

/**
 * Tests for querying partitioned Hive tables: behaviour when a partition directory has been
 * removed from disk, and reading a table partitioned by a timestamp column.
 */
class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  import spark.implicits._

  /**
   * Creates a table with four partitions, checks that all rows are returned, deletes one
   * partition directory from disk and checks that the rows of the three remaining
   * partitions are still returned.
   */
  private def queryWhenPathNotExist(): Unit = {
    withTempView("testData") {
      withTable("table_with_partition", "createAndInsertTest") {
        withTempDir { tmpDir =>
          val testData = sparkContext.parallelize(
            (1 to 10).map(i => TestData(i, i.toString))).toDF()
          testData.createOrReplaceTempView("testData")

          // create the table for test
          sql(s"CREATE TABLE table_with_partition(key int,value string) " +
              s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ")
          sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='1') " +
              "SELECT key,value FROM testData")
          sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='2') " +
              "SELECT key,value FROM testData")
          sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='3') " +
              "SELECT key,value FROM testData")
          sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='4') " +
              "SELECT key,value FROM testData")

          // all four partition paths exist: every row is expected four times
          checkAnswer(sql("select key,value from table_with_partition"),
            testData.union(testData).union(testData).union(testData))

          // delete the directory of one partition (the first "ds=" directory found)
          tmpDir.listFiles
              .find { f => f.isDirectory && f.getName().startsWith("ds=") }
              .foreach { f => Utils.deleteRecursively(f) }

          // after the deletion only three partitions contribute rows
          checkAnswer(sql("select key,value from table_with_partition"),
            testData.union(testData).union(testData))
        }
      }
    }
  }

  test("SPARK-5068: query data when path doesn't exist") {
    withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "true") {
      queryWhenPathNotExist()
    }
  }

  test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") {
    withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "false") {
      // `sparkContext.conf` is not managed by withSQLConf, so the previous value is saved
      // and put back by hand; otherwise the setting would leak into every test that runs
      // afterwards on the same context.
      val ignoreMissingFilesKey = IGNORE_MISSING_FILES.key
      val previousValue = sparkContext.conf.getOption(ignoreMissingFilesKey)
      sparkContext.conf.set(ignoreMissingFilesKey, "true")
      try {
        queryWhenPathNotExist()
      } finally {
        previousValue match {
          case Some(value) => sparkContext.conf.set(ignoreMissingFilesKey, value)
          case None => sparkContext.conf.remove(ignoreMissingFilesKey)
        }
      }
    }
  }

  test("SPARK-21739: Cast expression should initialize timezoneId") {
    withTable("table_with_timestamp_partition") {
      sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)")
      sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " +
        "PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)")

      // test for Cast expression in TableReader
      checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"),
        Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000"))))

      // test for Cast expression in HiveTableScanExec
      checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " +
        "WHERE ts = '2010-01-01 00:00:00.000'"), Row(1))
    }
  }
}
Example 10
Source File: TestHiveSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.hive.test.{TestHiveSingleton, TestHiveSparkSession}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils


/** Tests for the Hive test session: loading of test tables and of hive-site.xml. */
class TestHiveSuite extends TestHiveSingleton with SQLTestUtils {
  test("load test table based on case sensitivity") {
    val hiveSession = spark.asInstanceOf[TestHiveSparkSession]

    // Case-insensitive analysis: the upper-case name must load exactly the `src` table.
    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
      sql("SELECT * FROM SRC").queryExecution.analyzed
      val srcLoaded = hiveSession.getLoadedTables.contains("src")
      assert(srcLoaded)
      val loadedCount = hiveSession.getLoadedTables.size
      assert(loadedCount == 1)
    }
    hiveSession.reset()

    // Case-sensitive analysis: the upper-case name must not resolve.
    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
      val notFound = intercept[AnalysisException] {
        sql("SELECT * FROM SRC").queryExecution.analyzed
      }
      assert(notFound.message.contains("Table or view not found"))
    }
    hiveSession.reset()
  }

  test("SPARK-15887: hive-site.xml should be loaded") {
    val inTestFlag = hiveClient.getConf("hive.in.test", "")
    assert(inTestFlag == "true")
  }
}
Example 11
Source File: HiveParquetSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf

// Row type for the "Case insensitive attribute names" test. The field names deliberately
// differ in letter case and are queried as `upper` / `LOWER`, so they must not be renamed.
case class Cases(lower: String, UPPER: String)

/** Query tests that run against Parquet tables and Parquet-backed temporary views. */
class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton {

  test("Case insensitive attribute names") {
    val rows = (1 to 4).map { i => Cases(i.toString, i.toString) }
    withParquetTable(rows, "cases") {
      val expected = (1 to 4).map(i => Row(i.toString))
      // Column names are queried with the opposite letter case of their declaration.
      checkAnswer(sql("SELECT upper FROM cases"), expected)
      checkAnswer(sql("SELECT LOWER FROM cases"), expected)
    }
  }

  test("SELECT on Parquet table") {
    val pairs = (1 to 4).map { i => (i, s"val_$i") }
    withParquetTable(pairs, "t") {
      val expected = pairs.map(Row.fromTuple)
      checkAnswer(sql("SELECT * FROM t"), expected)
    }
  }

  test("Simple column projection + filter on Parquet table") {
    val triples = (1 to 4).map { i => (i % 2 == 0, i, s"val_$i") }
    withParquetTable(triples, "t") {
      // Only the even rows carry `_1 = true`.
      val expected = Seq(Row(true, "val_2"), Row(true, "val_4"))
      checkAnswer(sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), expected)
    }
  }

  test("Converting Hive to Parquet Table via saveAsParquetFile") {
    withTempPath { dir =>
      // Write `src` out as Parquet, read it back as view `p` and compare the two.
      sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath)
      spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("p")
      withTempView("p") {
        val fromParquet = sql("SELECT * from p ORDER BY key").collect().toSeq
        checkAnswer(sql("SELECT * FROM src ORDER BY key"), fromParquet)
      }
    }
  }

  test("INSERT OVERWRITE TABLE Parquet table") {
    // Don't run with vectorized: currently relies on UnsafeRow.
    val pairs = (1 to 10).map { i => (i, s"val_$i") }
    withParquetTable(pairs, "t", false) {
      withTempPath { file =>
        sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
        spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("p")
        withTempView("p") {
          // let's do three overwrites for good measure
          (1 to 3).foreach { _ =>
            sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
          }
          val expected = sql("SELECT * FROM t").collect().toSeq
          checkAnswer(sql("SELECT * FROM p"), expected)
        }
      }
    }
  }

  test("SPARK-25206: wrong records are returned by filter pushdown " +
    "when Hive metastore schema and parquet schema are in different letter cases") {
    withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
      withTempPath { path =>
        // The Parquet file uses a lower-case `id`; the table declares the column as `ID`.
        val data = spark.range(1, 10).toDF("id")
        data.write.parquet(path.getCanonicalPath)
        withTable("SPARK_25206") {
          sql("CREATE TABLE SPARK_25206 (ID LONG) USING parquet LOCATION " +
            s"'${path.getCanonicalPath}'")
          checkAnswer(sql("select id from SPARK_25206 where id > 0"), data)
        }
      }
    }
  }
}
Example 12
Source File: PruneFileSourcePartitionsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import org.scalatest.Matchers._

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, ResolvedHint}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.StructType

/** Tests for the `PruneFileSourcePartitions` rule on partitioned tables. */
class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {

  // Rule executor that applies only the rule under test, once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions))
  }

  test("PruneFileSourcePartitions should not change the output of LogicalRelation") {
    withTable("test") {
      withTempDir { dir =>
        sql(
          s"""
            |CREATE EXTERNAL TABLE test(i int)
            |PARTITIONED BY (p int)
            |STORED AS parquet
            |LOCATION '${dir.toURI}'""".stripMargin)

        val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
        val fileIndex = new CatalogFileIndex(spark, tableMeta, 0)

        // The data schema is the table schema without the partition columns.
        val partitionColumns = tableMeta.partitionColumnNames
        val dataSchema = StructType(
          tableMeta.schema.filterNot(field => partitionColumns.contains(field.name)))
        val relation = HadoopFsRelation(
          location = fileIndex,
          partitionSchema = tableMeta.partitionSchema,
          dataSchema = dataSchema,
          bucketSpec = None,
          fileFormat = new ParquetFileFormat(),
          options = Map.empty)(sparkSession = spark)

        val logicalRelation = LogicalRelation(relation, tableMeta)
        val query = Project(Seq('i, 'p), Filter('p === 1, logicalRelation)).analyze

        // Every attribute the optimized plan references must still be produced below it.
        val optimized = Optimize.execute(query)
        assert(optimized.missingInput.isEmpty)
      }
    }
  }

  test("SPARK-20986 Reset table's statistics after PruneFileSourcePartitions rule") {
    withTable("tbl") {
      spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl")
      sql("ANALYZE TABLE tbl COMPUTE STATISTICS")
      val tableStats = spark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")).stats
      assert(tableStats.isDefined && tableStats.get.sizeInBytes > 0, "tableStats is lost")

      val df = sql("SELECT * FROM tbl WHERE p = 1")

      // Before optimization the relation still carries the statistics of the whole table.
      val analyzedSizes = df.queryExecution.analyzed.collect {
        case rel: LogicalRelation => rel.catalogTable.get.stats.get.sizeInBytes
      }
      assert(analyzedSizes.size === 1, s"Size wrong for:\n ${df.queryExecution}")
      assert(analyzedSizes.head == tableStats.get.sizeInBytes)

      // After optimization the relation's size must be smaller than the whole table's and
      // must agree with the statistics stored on its catalog table.
      val optimizedRelations = df.queryExecution.optimizedPlan.collect {
        case rel: LogicalRelation => rel
      }
      assert(optimizedRelations.size === 1, s"Size wrong for:\n ${df.queryExecution}")
      val prunedRelation = optimizedRelations.head
      val prunedSize = prunedRelation.stats.sizeInBytes
      assert(prunedSize == prunedRelation.catalogTable.get.stats.get.sizeInBytes)
      assert(prunedSize < tableStats.get.sizeInBytes)
    }
  }

  test("SPARK-26576 Broadcast hint not applied to partitioned table") {
    withTable("tbl") {
      // With the threshold at -1 a broadcast join can only come from the explicit hint.
      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
        spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl")
        val table = spark.table("tbl")
        val execution = table.join(broadcast(table), "p").queryExecution
        execution.optimizedPlan.collect { case _: ResolvedHint => } should have size 1
        execution.sparkPlan.collect { case j: BroadcastHashJoinExec => j } should have size 1
      }
    }
  }
}
Example 13
Source File: FiltersSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.client

import java.util.Collections

import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


/**
 * Tests for `Shim_v0_13.convertFilters`: each case converts a list of Catalyst predicates
 * against a test table and compares the resulting filter string with the expected one.
 */
class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
  private val shim = new Shim_v0_13

  // Test table with a single partition column named "varchar" of varchar type.
  private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
  private val varCharCol = new FieldSchema()
  varCharCol.setName("varchar")
  varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
  testTable.setPartCols(Collections.singletonList(varCharCol))

  filterTest("string filter",
    Seq(a("stringcol", StringType) > Literal("test")),
    "stringcol > \"test\"")

  filterTest("string filter backwards",
    Seq(Literal("test") > a("stringcol", StringType)),
    "\"test\" > stringcol")

  filterTest("int filter",
    Seq(a("intcol", IntegerType) === Literal(1)),
    "intcol = 1")

  filterTest("int filter backwards",
    Seq(Literal(1) === a("intcol", IntegerType)),
    "1 = intcol")

  filterTest("int and string filter",
    Seq(
      Literal(1) === a("intcol", IntegerType),
      Literal("a") === a("strcol", IntegerType)),
    "1 = intcol and \"a\" = strcol")

  filterTest("skip varchar",
    Seq(Literal("") === a("varchar", StringType)),
    "")

  filterTest("SPARK-19912 String literals should be escaped for Hive metastore partition pruning",
    Seq(
      a("stringcol", StringType) === Literal("p1\" and q=\"q1"),
      Literal("p2\" and q=\"q2") === a("stringcol", StringType)),
    """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""")

  filterTest("SPARK-24879 null literals should be ignored for IN constructs",
    Seq(a("intcol", IntegerType) in (Literal(1), Literal(null))),
    "(intcol = 1)")

  // Applying the predicate `x IN (NULL)` should return an empty set, but since this optimization
  // will be applied by Catalyst, this filter converter does not need to account for this.
  filterTest("SPARK-24879 IN predicates with only NULLs will not cause a NPE",
    Seq(a("intcol", IntegerType) in Literal(null)),
    "")

  filterTest("typecast null literals should not be pushed down in simple predicates",
    Seq(a("intcol", IntegerType) === Literal(null, IntegerType)),
    "")

  /**
   * Registers a test named `name` that converts `filters` with advanced partition predicate
   * pushdown enabled and fails unless the converted string equals `result`.
   */
  private def filterTest(name: String, filters: Seq[Expression], result: String) = {
    test(name) {
      withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") {
        val actual = shim.convertFilters(testTable, filters)
        if (actual != result) {
          fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$actual'")
        }
      }
    }
  }

  test("turn on/off ADVANCED_PARTITION_PREDICATE_PUSHDOWN") {
    import org.apache.spark.sql.catalyst.dsl.expressions._
    for (enabled <- Seq(true, false)) {
      withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> enabled.toString) {
        // A disjunction is only expected to convert when the advanced pushdown is enabled.
        val disjunction =
          Literal(1) === a("intcol", IntegerType) || Literal(2) === a("intcol", IntegerType)
        val actual = shim.convertFilters(testTable, Seq(disjunction))
        if (enabled) {
          assert(actual == "(1 = intcol or 2 = intcol)")
        } else {
          assert(actual.isEmpty)
        }
      }
    }
  }

  // Shorthand for an attribute reference with the given name and data type.
  private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
}
Example 14
Source File: SparkSQLOperationManager.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver.server

import java.util.{Map => JMap}
import java.util.concurrent.ConcurrentHashMap

import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation}
import org.apache.spark.sql.internal.SQLConf


/**
 * Operation manager that builds [[SparkExecuteStatementOperation]]s bound to the `SQLContext`
 * registered for the requesting session in `sessionToContexts`.
 */
private[thriftserver] class SparkSQLOperationManager()
  extends OperationManager with Logging {

  // The parent class' handle-to-operation map, looked up reflectively by field name.
  val handleToOperation = ReflectionUtils
    .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")

  val sessionToActivePool = new ConcurrentHashMap[SessionHandle, String]()
  val sessionToContexts = new ConcurrentHashMap[SessionHandle, SQLContext]()

  /**
   * Creates and registers the operation that will run `statement` for `parentSession`.
   *
   * The session's overridden configurations and Hive variables are copied into the session's
   * SQL conf first. The operation runs in the background only when `async` is requested and
   * `HiveUtils.HIVE_THRIFT_SERVER_ASYNC` is enabled in that conf.
   */
  override def newExecuteStatementOperation(
      parentSession: HiveSession,
      statement: String,
      confOverlay: JMap[String, String],
      async: Boolean): ExecuteStatementOperation = synchronized {
    val sessionHandle = parentSession.getSessionHandle
    val sqlContext = sessionToContexts.get(sessionHandle)
    require(sqlContext != null, s"Session handle: $sessionHandle has not been" +
      " initialized or had already closed.")

    val sessionConf = sqlContext.sessionState.conf
    val hiveState = parentSession.getSessionState
    setConfMap(sessionConf, hiveState.getOverriddenConfigurations)
    setConfMap(sessionConf, hiveState.getHiveVariables)

    val runInBackground = async && sessionConf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC)
    val created = new SparkExecuteStatementOperation(
      parentSession, statement, confOverlay, runInBackground)(sqlContext, sessionToActivePool)
    handleToOperation.put(created.getHandle, created)
    logDebug(s"Created Operation for $statement with session=$parentSession, " +
      s"runInBackground=$runInBackground")
    created
  }

  /** Copies every entry of `confMap` into `conf` as a string setting. */
  def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = {
    val entries = confMap.entrySet().iterator()
    while (entries.hasNext) {
      val entry = entries.next()
      conf.setConfString(entry.getKey, entry.getValue)
    }
  }
}
Example 15
Source File: CodeGeneratorWithInterpretedFallback.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils


/**
 * Base class for factories that can build an `OUT` from an `IN` either through generated code
 * or through an interpreted implementation.
 */
abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging {

  /**
   * Builds the object for `in`.
   *
   * Under tests, `SQLConf.CODEGEN_FACTORY_MODE` may force the codegen-only or the no-codegen
   * path. In every other situation code generation is attempted first and any non-fatal
   * failure falls back to the interpreted implementation.
   */
  def createObject(in: IN): OUT = {
    val modeName = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE)

    CodegenObjectFactoryMode.withName(modeName) match {
      // The forced modes are only honoured when running under tests.
      case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting =>
        createCodeGeneratedObject(in)

      case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting =>
        createInterpretedObject(in)

      case _ =>
        try {
          createCodeGeneratedObject(in)
        } catch {
          // The failure itself is not logged here; `CodeGenerator` is expected to have
          // reported it already.
          case NonFatal(_) =>
            logWarning("Expr codegen error and falling back to interpreter mode")
            createInterpretedObject(in)
        }
    }
  }

  /** Builds the object through generated code. */
  protected def createCodeGeneratedObject(in: IN): OUT

  /** Builds the object through the interpreted implementation. */
  protected def createInterpretedObject(in: IN): OUT
}
Example 16
Source File: view.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf


/**
 * Rule that replaces every `View` operator with its child, asserting that both expose the
 * same output attributes.
 */
object EliminateView extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case View(_, viewOutput, underlying) =>
      // The view and its child must agree on output attributes, so the View node can go.
      val childOutput = underlying.output
      assert(viewOutput == childOutput,
        s"The output of the child ${childOutput.mkString("[", ",", "]")} is different from the " +
          s"view output ${viewOutput.mkString("[", ",", "]")}")
      underlying
  }
}
Example 17
Source File: SubstituteUnresolvedOrdinals.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType


/**
 * Rule that rewrites integer literals used as sort keys or grouping expressions into
 * `UnresolvedOrdinal`s, when the matching `conf` flag (`orderByOrdinal` / `groupByOrdinal`)
 * is enabled.
 */
class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {

  /** True when `e` is a literal of `IntegerType`. */
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case sort: Sort if conf.orderByOrdinal && sort.order.exists(o => isIntLiteral(o.child)) =>
      val rewrittenOrder = sort.order.map {
        case sortOrder @ SortOrder(lit @ Literal(position: Int, IntegerType), _, _, _) =>
          // Keep the origin of both the literal and the enclosing sort order.
          val placeholder = withOrigin(lit.origin)(UnresolvedOrdinal(position))
          withOrigin(sortOrder.origin)(sortOrder.copy(child = placeholder))
        case untouched => untouched
      }
      withOrigin(sort.origin)(sort.copy(order = rewrittenOrder))

    case agg: Aggregate
        if conf.groupByOrdinal && agg.groupingExpressions.exists(isIntLiteral) =>
      val rewrittenGroups = agg.groupingExpressions.map {
        case lit @ Literal(position: Int, IntegerType) =>
          withOrigin(lit.origin)(UnresolvedOrdinal(position))
        case untouched => untouched
      }
      withOrigin(agg.origin)(agg.copy(groupingExpressions = rewrittenGroups))
  }
}
Example 18
Source File: ResolveInlineTables.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}


  /**
   * Turns an inline table into a `LocalRelation`: a schema is derived column by column, then
   * every cell expression is cast to its column type when needed and evaluated.
   * Analysis fails if a column has incompatible types or a cell cannot be evaluated.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // One field per column: a common type over all of its values, nullable if any value is.
    val fields = table.rows.transpose.zip(table.names).map { case (values, columnName) =>
      val valueTypes = values.map(_.dataType)
      val commonType = TypeCoercion.findWiderTypeWithoutStringPromotion(valueTypes).getOrElse {
        table.failAnalysis(s"incompatible types found in column $columnName for inline table")
      }
      StructField(columnName, commonType, nullable = values.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      val evaluatedCells = row.zipWithIndex.map { case (expr, columnIndex) =>
        val targetType = fields(columnIndex).dataType
        try {
          // Only insert a cast when the cell's type differs from the column type.
          val toEvaluate = if (expr.dataType.sameType(targetType)) expr else cast(expr, targetType)
          toEvaluate.eval()
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${expr.sql}: ${ex.getMessage}", ex)
        }
      }
      InternalRow.fromSeq(evaluatedCells)
    }

    LocalRelation(attributes, newRows)
  }
} 
Example 19
Source File: DateFormatterSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.util

import java.time.LocalDate

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf

/** Tests parsing and formatting of dates by `DateFormatter` across session time zones. */
class DateFormatterSuite extends SparkFunSuite with SQLHelper {

  /**
   * Runs `body` once for each id in `DateTimeTestUtils.outstandingTimezonesIds`, with that id
   * set as the session-local time zone.
   */
  private def withEachTimeZone(body: => Unit): Unit = {
    DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
        body
      }
    }
  }

  test("parsing dates") {
    withEachTimeZone {
      val parsedDays = DateFormatter().parse("2018-12-02")
      assert(parsedDays === 17867)
    }
  }

  test("format dates") {
    withEachTimeZone {
      val rendered = DateFormatter().format(17867)
      assert(rendered === "2018-12-02")
    }
  }

  test("roundtrip date -> days -> date") {
    val dates = Seq(
      "0050-01-01",
      "0953-02-02",
      "1423-03-08",
      "1969-12-31",
      "1972-08-25",
      "1975-09-26",
      "2018-12-12",
      "2038-01-01",
      "5010-11-17")
    dates.foreach { date =>
      withEachTimeZone {
        val formatter = DateFormatter()
        val formatted = formatter.format(formatter.parse(date))
        assert(date === formatted)
      }
    }
  }

  test("roundtrip days -> date -> days") {
    val dayCounts = Seq(
      -701265,
      -371419,
      -199722,
      -1,
      0,
      967,
      2094,
      17877,
      24837,
      1110657)
    dayCounts.foreach { days =>
      withEachTimeZone {
        val formatter = DateFormatter()
        val parsed = formatter.parse(formatter.format(days))
        assert(days === parsed)
      }
    }
  }

  test("parsing date without explicit day") {
    // With no day in the pattern, the expected result is the first day of the month.
    val expectedDays = LocalDate.of(2018, 12, 1).toEpochDay
    val parsedDays = DateFormatter("yyyy MMM").parse("2018 Dec")
    assert(parsedDays === expectedDays)
  }
}
Example 20
Source File: StatsEstimationTestBase.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}


trait StatsEstimationTestBase extends SparkFunSuite {

  // Holds the CBO_ENABLED setting read in beforeAll() so that afterAll() can write it back.
  var originalValue: Boolean = false

  /** Remembers the current CBO_ENABLED setting, then switches CBO on for the suite. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    val activeConf = SQLConf.get
    originalValue = activeConf.getConf(SQLConf.CBO_ENABLED)
    // Stats estimation in these tests relies on CBO being enabled.
    activeConf.setConf(SQLConf.CBO_ENABLED, true)
  }

  /**
   * Writes the CBO_ENABLED value saved in `originalValue` back to the active conf.
   *
   * `super.afterAll()` runs in a `finally` block so the parent's cleanup still happens
   * even if restoring the configuration throws.
   */
  override def afterAll(): Unit = {
    try {
      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
    } finally {
      super.afterAll()
    }
  }

  /**
   * Size used for a column of `attribute`: the column's average length when present,
   * otherwise the data type's default size. String columns get an extra 8 + 4 added.
   */
  def getColSize(attribute: Attribute, colStat: ColumnStat): Long = {
    val dataType = attribute.dataType
    val baseLength: Long = colStat.avgLen.getOrElse(dataType.defaultSize.toLong)
    dataType match {
      // For UTF8String: base + offset + numBytes
      case StringType => baseLength + 8 + 4
      case _ => baseLength
    }
  }

  /** Creates an `IntegerType` attribute reference named `colName`. */
  def attr(colName: String): AttributeReference = {
    AttributeReference(colName, IntegerType)()
  }

  
/**
 * Leaf plan for tests whose statistics are supplied directly through its constructor.
 *
 * @param outputList attributes reported as the plan's output
 * @param rowCount row count reported in the statistics
 * @param attributeStats per-attribute column statistics reported in the statistics
 * @param size optional size in bytes; a placeholder value is used when absent
 */
case class StatsTestPlan(
    outputList: Seq[Attribute],
    rowCount: BigInt,
    attributeStats: AttributeMap[ColumnStat],
    size: Option[BigInt] = None) extends LeafNode {

  override def output: Seq[Attribute] = outputList

  override def computeStats(): Statistics = {
    // When a test does not care about sizeInBytes, a fake value is reported instead.
    val placeholderSize: BigInt = Int.MaxValue
    Statistics(
      sizeInBytes = size.getOrElse(placeholderSize),
      rowCount = Some(rowCount),
      attributeStats = attributeStats)
  }
}
Example 21
Source File: CodeGeneratorWithInterpretedFallbackSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.util.concurrent.ExecutionException

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

/** Tests how `CodeGeneratorWithInterpretedFallback` reacts to each codegen factory mode. */
class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase {

  /** A factory whose code-generated path compiles invalid code, so only the fallback can work. */
  object FailedCodegenProjection
      extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] {

    override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
      // Compiling this is assumed to throw, so the trailing null is never returned.
      CodeGenerator.compile(new CodeAndComment("invalid code", Map.empty))
      null
    }

    override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {
      InterpretedUnsafeProjection.createProjection(in)
    }
  }

  /** A single nullable integer column bound at ordinal 0. */
  private def singleIntInput = Seq(BoundReference(0, IntegerType, nullable = true))

  test("UnsafeProjection with codegen factory mode") {
    val input = singleIntInput

    val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
      val projection = UnsafeProjection.createObject(input)
      assert(projection.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection"))
    }

    val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString
    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) {
      val projection = UnsafeProjection.createObject(input)
      assert(projection.isInstanceOf[InterpretedUnsafeProjection])
    }
  }

  test("fallback to the interpreter mode") {
    val input = singleIntInput
    val fallback = CodegenObjectFactoryMode.FALLBACK.toString
    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallback) {
      val projection = FailedCodegenProjection.createObject(input)
      assert(projection.isInstanceOf[InterpretedUnsafeProjection])
    }
  }

  test("codegen failures in the CODEGEN_ONLY mode") {
    val thrown = intercept[ExecutionException] {
      val input = singleIntInput
      val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
        FailedCodegenProjection.createObject(input)
      }
    }
    val errMsg = thrown.getMessage
    assert(errMsg.contains("failed to compile: org.codehaus.commons.compiler.CompileException:"))
  }
}
Example 22
Source File: AnalysisTest.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import java.net.URI
import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

/**
 * Test helpers for analyzer suites: provides a case-sensitive and a case-insensitive analyzer
 * over a small in-memory catalog, plus assertions on analysis results and errors.
 */
trait AnalysisTest extends PlanTest {

  protected val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true)
  protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false)

  /**
   * Builds an analyzer over a new in-memory catalog holding a "default" database and the
   * temp views "TaBlE", "TaBlE2" and "TaBlE3".
   */
  private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
    val analyzerConf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
    val sessionCatalog =
      new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, analyzerConf)
    sessionCatalog.createDatabase(
      CatalogDatabase("default", "", new URI("loc"), Map.empty),
      ignoreIfExists = false)

    val tempViews = Seq(
      "TaBlE" -> TestRelations.testRelation,
      "TaBlE2" -> TestRelations.testRelation2,
      "TaBlE3" -> TestRelations.testRelation3)
    tempViews.foreach { case (viewName, relation) =>
      sessionCatalog.createTempView(viewName, relation, overrideIfExists = true)
    }

    new Analyzer(sessionCatalog, analyzerConf) {
      override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
    }
  }

  /** Picks the pre-built analyzer matching the requested case sensitivity. */
  protected def getAnalyzer(caseSensitive: Boolean) = {
    if (!caseSensitive) caseInsensitiveAnalyzer else caseSensitiveAnalyzer
  }

  /** Analyzes `inputPlan` (including the analysis check) and compares it with `expectedPlan`. */
  protected def checkAnalysis(
      inputPlan: LogicalPlan,
      expectedPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzed = getAnalyzer(caseSensitive).executeAndCheck(inputPlan)
    comparePlans(analyzed, expectedPlan)
  }

  protected override def comparePlans(
      plan1: LogicalPlan,
      plan2: LogicalPlan,
      checkAnalysis: Boolean = false): Unit = {
    // Plans in analysis tests may be only partially resolved, hence the `false` default above.
    super.comparePlans(plan1, plan2, checkAnalysis)
  }

  /** Fails the test, showing the input and partially analyzed plans, if analysis is rejected. */
  protected def assertAnalysisSuccess(
      inputPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val analysisAttempt = analyzer.execute(inputPlan)
    try {
      analyzer.checkAnalysis(analysisAttempt)
    } catch {
      case a: AnalysisException =>
        fail(
          s"""
            |Failed to Analyze Plan
            |$inputPlan
            |
            |Partial Analysis
            |$analysisAttempt
          """.stripMargin, a)
    }
  }

  /**
   * Expects analysis of `inputPlan` to throw an `AnalysisException` whose message contains
   * every string in `expectedErrors`, compared case-insensitively.
   */
  protected def assertAnalysisError(
      inputPlan: LogicalPlan,
      expectedErrors: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val e = intercept[AnalysisException] {
      analyzer.checkAnalysis(analyzer.execute(inputPlan))
    }

    val loweredMessage = e.getMessage.toLowerCase(Locale.ROOT)
    val allPresent = expectedErrors.forall { expected =>
      loweredMessage.contains(expected.toLowerCase(Locale.ROOT))
    }
    if (!allPresent) {
      fail(
        s"""Exception message should contain the following substrings:
           |
           |  ${expectedErrors.mkString("\n  ")}
           |
           |Actual exception message:
           |
           |  ${e.getMessage}
         """.stripMargin)
    }
  }
}
Example 23
Source File: SubstituteUnresolvedOrdinalsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.internal.SQLConf

/**
 * Tests for `SubstituteUnresolvedOrdinals`: integer literals in ORDER BY / GROUP BY become
 * `UnresolvedOrdinal`s unless the corresponding conf flag is disabled.
 */
class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest {
  // First and second output attributes of testRelation2, used as the expected resolved columns.
  private lazy val a = testRelation2.output(0)
  private lazy val b = testRelation2.output(1)

  // The previous test name ("... should not be unresolved") contradicted the assertion below.
  test("unresolved ordinal should not be resolved") {
    // An UnresolvedOrdinal expression must report itself as not resolved.
    assert(!UnresolvedOrdinal(0).resolved)
  }

  test("order by ordinal") {
    // Tests order by ordinal, apply single rule.
    val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc)
    comparePlans(
      new SubstituteUnresolvedOrdinals(conf).apply(plan),
      testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc))

    // Tests order by ordinal, do full analysis
    checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc))

    // order by ordinal can be turned off by config: the literals must then be left untouched
    comparePlans(
      new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.ORDER_BY_ORDINAL -> false)).apply(plan),
      testRelation2.orderBy(Literal(1).asc, Literal(2).asc))
  }

  test("group by ordinal") {
    // Tests group by ordinal, apply single rule.
    val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)
    comparePlans(
      new SubstituteUnresolvedOrdinals(conf).apply(plan2),
      testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b))

    // Tests group by ordinal, do full analysis
    checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b))

    // group by ordinal can be turned off by config: the literals must then be left untouched
    comparePlans(
      new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.GROUP_BY_ORDINAL -> false)).apply(plan2),
      testRelation2.groupBy(Literal(1), Literal(2))('a, 'b))
  }
}
Example 24
Source File: LookupFunctionsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import java.net.URI

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

/** Tests how often `LookupFunctions` asks the catalog / registry whether a function exists. */
class LookupFunctionsSuite extends PlanTest {

  test("SPARK-23486: the functionExists for the Persistent function check") {
    val countingCatalog = new CustomInMemoryCatalog
    val sqlConf = new SQLConf()
    val sessionCatalog = new SessionCatalog(countingCatalog, FunctionRegistry.builtin, sqlConf)
    sessionCatalog.createDatabase(
      CatalogDatabase("default", "", new URI("loc"), Map.empty),
      ignoreIfExists = false)
    val analyzer = new Analyzer(sessionCatalog, sqlConf)

    def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
    val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false)
    val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
    // Three calls to the same persistent function followed by two calls to a registered one.
    val projectList = Seq(
      Alias(unresolvedPersistentFunc, "call1")(),
      Alias(unresolvedPersistentFunc, "call2")(),
      Alias(unresolvedPersistentFunc, "call3")(),
      Alias(unresolvedRegisteredFunc, "call4")(),
      Alias(unresolvedRegisteredFunc, "call5")())
    val plan = Project(projectList, table("TaBlE"))
    analyzer.LookupFunctions.apply(plan)

    assert(countingCatalog.getFunctionExistsCalledTimes == 1)
    val normalized = analyzer.LookupFunctions.normalizeFuncName(unresolvedPersistentFunc.name)
    assert(normalized.database == Some("default"))
  }

  test("SPARK-23486: the functionExists for the Registered function check") {
    val externalCatalog = new InMemoryCatalog
    val sqlConf = new SQLConf()
    val countingRegistry = new CustomerFunctionRegistry
    val sessionCatalog = new SessionCatalog(externalCatalog, countingRegistry, sqlConf)
    sessionCatalog.createDatabase(
      CatalogDatabase("default", "", new URI("loc"), Map.empty),
      ignoreIfExists = false)
    val analyzer = new Analyzer(sessionCatalog, sqlConf)

    def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
    val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
    val projectList = Seq(
      Alias(unresolvedRegisteredFunc, "call1")(),
      Alias(unresolvedRegisteredFunc, "call2")())
    val plan = Project(projectList, table("TaBlE"))
    analyzer.LookupFunctions.apply(plan)

    assert(countingRegistry.getIsRegisteredFunctionCalledTimes == 2)
    val normalized = analyzer.LookupFunctions.normalizeFuncName(unresolvedRegisteredFunc.name)
    assert(normalized.database == Some("default"))
  }
}

/** Function registry that reports every function as existing and counts the lookups. */
class CustomerFunctionRegistry extends SimpleFunctionRegistry {

  // Number of functionExists calls seen so far.
  private var isRegisteredFunctionCalledTimes: Int = 0

  override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized {
    isRegisteredFunctionCalledTimes += 1
    true
  }

  /** Returns how many times `functionExists` has been called. */
  def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes
}

/** In-memory catalog that reports every function as existing and counts the lookups. */
class CustomInMemoryCatalog extends InMemoryCatalog {

  // Number of functionExists calls seen so far.
  private var functionExistsCalledTimes: Int = 0

  override def functionExists(db: String, funcName: String): Boolean = synchronized {
    functionExistsCalledTimes += 1
    true
  }

  /** Returns how many times `functionExists` has been called. */
  def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes
}
Example 25
Source File: OptimizerStructuralIntegrityCheckerSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf


/** Checks that the optimizer reports a rule that leaves the plan structurally broken. */
class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {

  /** A deliberately bad rule: appends an unresolved attribute to every Project list. */
  object OptimizeRuleBreakSI extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case Project(projectList, child) =>
        val unresolved = UnresolvedAttribute("unresolvedAttr")
        Project(projectList ++ Seq(unresolved), child)
    }
  }

  /** Optimizer that runs the breaking rule once, ahead of the default batches. */
  object Optimize extends Optimizer(
    new SessionCatalog(
      new InMemoryCatalog,
      EmptyFunctionRegistry,
      new SQLConf())) {
    val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
    override def defaultBatches: Seq[Batch] = newBatch +: super.defaultBatches
  }

  test("check for invalid plan after execution of rule") {
    val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze
    assert(analyzed.resolved)

    val failure = intercept[TreeNodeException[LogicalPlan]] {
      Optimize.execute(analyzed)
    }
    val message = failure.getMessage
    val ruleName = OptimizeRuleBreakSI.ruleName
    assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
    assert(message.contains("the structural integrity of the plan is broken"))
  }
}
Example 26
Source File: RewriteDistinctAggregatesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

/** Tests which aggregate shapes `RewriteDistinctAggregates` rewrites and which it leaves alone. */
class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  /** Fails unless `rewrite` has the shape Aggregate over Aggregate over Expand. */
  private def checkRewrite(rewrite: LogicalPlan): Unit = {
    val hasRewrittenShape = rewrite match {
      case Aggregate(_, _, Aggregate(_, _, _: Expand)) => true
      case _ => false
    }
    if (!hasRewrittenShape) {
      fail(s"Plan is not rewritten:\n$rewrite")
    }
  }

  test("single distinct group") {
    val original = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    // A single distinct group is expected to come back unchanged.
    comparePlans(original, RewriteDistinctAggregates(original))
  }

  test("single distinct group with partial aggregates") {
    val original = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    // Still a single distinct group, so the plan is expected to come back unchanged.
    comparePlans(original, RewriteDistinctAggregates(original))
  }

  test("multiple distinct groups") {
    val original = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(original))
  }

  test("multiple distinct groups with partial aggregates") {
    val original = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(original))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val original = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(original))
  }
}
Example 27
Source File: EliminateSortsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL}

/** Tests that `EliminateSorts` (run together with `FoldablePropagation`) drops no-op sort keys. */
class EliminateSortsSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, ORDER_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Eliminate Sorts", FixedPoint(10),
        FoldablePropagation,
        EliminateSorts) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Empty order by clause") {
    val unsorted = testRelation.orderBy()

    val optimized = Optimize.execute(unsorted.analyze)
    val correctAnswer = testRelation.analyze

    comparePlans(optimized, correctAnswer)
  }

  test("All the SortOrder are no-op") {
    // Both sort keys are constants, so the whole sort is expected to disappear.
    val sortedByConstants =
      testRelation.orderBy(SortOrder(3, Ascending), SortOrder(-1, Ascending))

    val optimized = Optimize.execute(analyzer.execute(sortedByConstants))
    val correctAnswer = analyzer.execute(testRelation)

    comparePlans(optimized, correctAnswer)
  }

  test("Partial order-by clauses contain no-op SortOrder") {
    // Only the constant key is expected to be dropped; the column key stays.
    val sortedByMixedKeys = testRelation.orderBy(SortOrder(3, Ascending), 'a.asc)

    val optimized = Optimize.execute(analyzer.execute(sortedByMixedKeys))
    val correctAnswer = analyzer.execute(testRelation.orderBy('a.asc))

    comparePlans(optimized, correctAnswer)
  }

  test("Remove no-op alias") {
    val sortedByAliases = testRelation
      .select('a.as('x), Year(CurrentDate()).as('y), 'b)
      .orderBy('x.asc, 'y.asc, 'b.desc)

    val optimized = Optimize.execute(analyzer.execute(sortedByAliases))
    // The expected plan no longer sorts by the alias 'y.
    val correctAnswer = analyzer.execute(
      testRelation.select('a.as('x), Year(CurrentDate()).as('y), 'b).orderBy('x.asc, 'b.desc))

    comparePlans(optimized, correctAnswer)
  }
}
Example 28
Source File: AggregateOptimizeSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}

/**
 * Checks how the `Optimize` batch defined below (FoldablePropagation,
 * RemoveLiteralFromGroupExpressions and RemoveRepetitionFromGroupExpressions) rewrites
 * the grouping expressions of aggregates built on a three-column local relation.
 */
class AggregateOptimizeSuite extends PlanTest {
  // Case-insensitive resolution, with GROUP BY ordinals switched off.
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("Aggregate", FixedPoint(100),
        FoldablePropagation,
        RemoveLiteralFromGroupExpressions,
        RemoveRepetitionFromGroupExpressions))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  // Resolves `plan` with this suite's analyzer and then runs the rule batch on the result.
  private def analyzeAndOptimize(plan: LogicalPlan): LogicalPlan =
    Optimize.execute(analyzer.execute(plan))

  test("remove literals in grouping expression") {
    val grouped = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))

    val actualPlan = analyzeAndOptimize(grouped)
    val expectedPlan = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(actualPlan, expectedPlan)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val grouped = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))

    val actualPlan = analyzeAndOptimize(grouped)
    // A single literal grouping key (0) is expected to remain.
    val expectedPlan = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(actualPlan, expectedPlan)
  }

  test("Remove aliased literals") {
    val grouped = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b))

    val actualPlan = analyzeAndOptimize(grouped)
    val expectedPlan = testRelation
      .select('a, 'b, Literal(1).as('y))
      .groupBy('a)(sum('b))
      .analyze

    comparePlans(actualPlan, expectedPlan)
  }

  test("remove repetition in grouping expression") {
    // 'A and 'B differ from 'a and 'b only by case; resolution here is case-insensitive.
    val grouped =
      testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))

    val actualPlan = analyzeAndOptimize(grouped)
    val expectedPlan = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(actualPlan, expectedPlan)
  }
}
Example 29
Source File: SparkPlanner.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
import org.apache.spark.sql.internal.SQLConf

/**
 * Planner that combines user-supplied strategies with the built-in ones listed in
 * [[strategies]] and offers [[pruneFilterProject]] as a helper for building scans.
 */
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val experimentalMethods: ExperimentalMethods)
  extends SparkStrategies {

  def numPartitions: Int = conf.numShufflePartitions

  // Experimental strategies come first, then extraPlanningStrategies, then the built-ins.
  override def strategies: Seq[Strategy] =
    experimentalMethods.extraStrategies ++
      extraPlanningStrategies ++
      Seq[Strategy](
        PythonEvals,
        DataSourceV2Strategy,
        FileSourceStrategy,
        DataSourceStrategy(conf),
        SpecialLimits,
        Aggregation,
        Window,
        JoinSelection,
        InMemoryScans,
        BasicOperators)

  /**
   * Builds a scan for the requested columns and layers a filter and, when needed, a
   * projection on top of it.
   *
   * @param projectList expressions the resulting plan must output
   * @param filterPredicates predicates to apply to the scanned rows
   * @param prunePushedDownFilters maps the predicates to the ones that still have to be
   *                               evaluated by a FilterExec
   * @param scanBuilder creates the scan for a given list of attributes
   * @return the scan, wrapped in a FilterExec when predicates remain and in a ProjectExec
   *         when column pruning alone cannot produce `projectList`
   */
  def pruneFilterProject(
      projectList: Seq[NamedExpression],
      filterPredicates: Seq[Expression],
      prunePushedDownFilters: Seq[Expression] => Seq[Expression],
      scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {

    val projectRefs = AttributeSet(projectList.flatMap(_.references))
    val filterRefs = AttributeSet(filterPredicates.flatMap(_.references))
    val residualCondition: Option[Expression] =
      prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And)

    // Puts a FilterExec on top of `scan` only when some predicate is left to evaluate.
    def withFilter(scan: SparkPlan): SparkPlan =
      residualCondition.fold(scan)(FilterExec(_, scan))

    // Right now we still use a projection even if the only evaluation is applying an alias
    // to a column.  Since this is a no-op, it could be avoided. However, using this
    // optimization with the current implementation would change the output schema.
    // TODO: Decouple final output schema from expression evaluation so this copy can be
    // avoided safely.

    // True when the project list is made of plain attributes that already cover every
    // column referenced by the filters.
    val pruningSuffices =
      AttributeSet(projectList.map(_.toAttribute)) == projectRefs &&
        filterRefs.subsetOf(projectRefs)

    if (pruningSuffices) {
      // Column pruning alone gives the right output: scan, then filter, no extra project.
      withFilter(scanBuilder(projectList.asInstanceOf[Seq[Attribute]]))
    } else {
      // Scan every referenced column, filter, then project down to the requested output.
      ProjectExec(projectList, withFilter(scanBuilder((projectRefs ++ filterRefs).toSeq)))
    }
  }
}
Example 30
Source File: OrcOptions.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.orc

import java.util.Locale

import org.apache.orc.OrcConf.COMPRESS

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf


  /**
   * ORC codec name resolved from the options, compared case-insensitively against the
   * known short names. Throws IllegalArgumentException for an unknown codec name.
   */
  val compressionCodec: String = {
    // Precedence, highest first: `compression`, then `orc.compress` (i.e. OrcConf.COMPRESS),
    // then the `spark.sql.orc.compression.codec` session setting.
    val fromOrcConf = parameters.get(COMPRESS.getAttribute)
    val requested = parameters
      .get("compression")
      .orElse(fromOrcConf)
      .getOrElse(sqlConf.orcCompressionCodec)
      .toLowerCase(Locale.ROOT)

    shortOrcCompressionCodecNames.getOrElse(requested, {
      val known = shortOrcCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT))
      throw new IllegalArgumentException(s"Codec [$requested] " +
        s"is not available. Available codecs are ${known.mkString(", ")}.")
    })
  }
}

object OrcOptions {
  // Accepted short names (lower case), each paired with the codec name it stands for.
  private val shortOrcCompressionCodecNames = Map(
    ("none", "NONE"),
    ("uncompressed", "NONE"),
    ("snappy", "SNAPPY"),
    ("zlib", "ZLIB"),
    ("lzo", "LZO"))

  /**
   * Translates a short codec name into the codec name it maps to.
   *
   * The lookup is exact (no case folding); a name that is not in the table makes the
   * underlying map lookup throw.
   */
  def getORCCompressionCodecName(name: String): String = {
    shortOrcCompressionCodecNames.apply(name)
  }
}
Example 31
Source File: DataSourceV2Utils.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.v2

import java.util.regex.Pattern

import org.apache.spark.internal.Logging
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport}

private[sql] object DataSourceV2Utils extends Logging {

  /**
   * Collects the session configs addressed to a data source.
   *
   * For a source that implements [[SessionConfigSupport]], every entry of `conf` whose key
   * has the form `spark.datasource.<keyPrefix>.<name>` (with a non-empty `<name>`) is
   * returned under the key `<name>`. Any other source yields an empty map.
   *
   * @param ds the data source to extract configs for
   * @param conf the session conf whose entries are inspected
   * @return the matching entries, keyed by the part of the key after the prefix
   * @throws IllegalArgumentException if the source reports a null key prefix
   */
  def extractSessionConfigs(ds: DataSourceV2, conf: SQLConf): Map[String, String] = ds match {
    case cs: SessionConfigSupport =>
      val keyPrefix = cs.keyPrefix()
      require(keyPrefix != null, "The data source config key prefix can't be null.")

      // Quote the prefix so that regex metacharacters inside it (for example '.') are
      // matched literally instead of being interpreted as part of the pattern.
      val quotedPrefix = Pattern.quote(keyPrefix)
      val pattern = Pattern.compile(s"^spark\\.datasource\\.$quotedPrefix\\.(.+)")

      conf.getAllConfs.flatMap { case (key, value) =>
        val m = pattern.matcher(key)
        // The pattern has exactly one capturing group, so a match always provides group(1).
        if (m.matches()) {
          Seq((m.group(1), value))
        } else {
          Seq.empty
        }
      }

    case _ => Map.empty
  }
}
Example 32
Source File: FailureSafeParser.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

/**
 * Wraps a raw parser and turns a [[BadRecordException]] raised by it into a row, no row,
 * or a failure, depending on the parse mode.
 */
class FailureSafeParser[IN](
    rawParser: IN => Seq[InternalRow],
    mode: ParseMode,
    schema: StructType,
    columnNameOfCorruptRecord: String) {

  private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
  private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
  // A single row instance that is overwritten and handed out again on every conversion.
  private val resultRow = new GenericInternalRow(schema.length)
  private val nullResult = new GenericInternalRow(schema.length)

  // Takes an optional partial result and a producer of the bad record. Without a corrupt
  // record column in the schema, the partial result (or an all-null row) is returned as is.
  // With one, the bad record is stored in that column and the remaining columns are filled
  // from the partial result, or with null when there is none.
  private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = {
    corruptFieldIndex match {
      case Some(corruptIdx) =>
        (partial, badRecord) => {
          var pos = 0
          while (pos < actualSchema.length) {
            val field = actualSchema(pos)
            // Position `pos` in the parsed row corresponds to this field's slot in `schema`.
            val target = schema.fieldIndex(field.name)
            resultRow(target) = partial.map(_.get(pos, field.dataType)).orNull
            pos += 1
          }
          resultRow(corruptIdx) = badRecord()
          resultRow
        }
      case None =>
        (partial, _) => partial.getOrElse(nullResult)
    }
  }

  /**
   * Parses one input.
   *
   * On a [[BadRecordException]], permissive mode yields one row built from the partial
   * result, drop-malformed mode yields nothing, and fail-fast mode throws a
   * [[SparkException]] carrying the original cause.
   */
  def parse(input: IN): Iterator[InternalRow] = {
    try {
      val parsedRows = rawParser(input)
      parsedRows.toIterator.map(good => toResultRow(Some(good), () => null))
    } catch {
      case bad: BadRecordException =>
        mode match {
          case PermissiveMode =>
            Iterator(toResultRow(bad.partialResult(), bad.record))
          case DropMalformedMode =>
            Iterator.empty
          case FailFastMode =>
            throw new SparkException("Malformed records are detected in record parsing. " +
              s"Parse Mode: ${FailFastMode.name}.", bad.cause)
        }
    }
  }
}
Example 33
Source File: SQLHadoopMapReduceCommitProtocol.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter

import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
import org.apache.spark.sql.internal.SQLConf


/**
 * Commit protocol that lets the class named by SQLConf.OUTPUT_COMMITTER_CLASS replace the
 * committer chosen by the parent protocol.
 */
class SQLHadoopMapReduceCommitProtocol(
    jobId: String,
    path: String,
    dynamicPartitionOverwrite: Boolean = false)
  extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite)
    with Serializable with Logging {

  /**
   * Returns the committer to use: an instance of the user-configured class when one is set,
   * otherwise the committer produced by the parent implementation.
   */
  override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
    val defaultCommitter = super.setupCommitter(context)

    val userClass = context.getConfiguration
      .getClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter])

    val committer: OutputCommitter =
      if (userClass == null) {
        defaultCommitter
      } else {
        logInfo(s"Using user defined output committer class ${userClass.getCanonicalName}")

        // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat
        // has an associated output committer. To override this output committer,
        // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS.
        // If a data source needs to override the output committer, it needs to set the
        // output committer in prepareForWrite method.
        if (classOf[FileOutputCommitter].isAssignableFrom(userClass)) {
          // A FileOutputCommitter subclass: use its (Path, TaskAttemptContext) constructor.
          userClass
            .getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
            .newInstance(new Path(path), context)
        } else {
          // A plain OutputCommitter: use its no-argument constructor.
          userClass.getDeclaredConstructor().newInstance()
        }
      }

    logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}")
    committer
  }
}
Example 34
Source File: ParquetOptions.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import java.util.Locale

import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.parquet.hadoop.metadata.CompressionCodecName

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf


  // Uses the `mergeSchema` option when it is present (parsed as a boolean); otherwise falls
  // back to the session-level schema merging setting.
  val mergeSchema: Boolean =
    parameters.get(MERGE_SCHEMA).fold(sqlConf.isParquetSchemaMergingEnabled)(_.toBoolean)
}


object ParquetOptions {
  // Name of the option that controls schema merging.
  val MERGE_SCHEMA = "mergeSchema"

  // Accepted short names (lower case), each paired with the Parquet codec it stands for.
  private val shortParquetCompressionCodecNames = Map(
    ("none", CompressionCodecName.UNCOMPRESSED),
    ("uncompressed", CompressionCodecName.UNCOMPRESSED),
    ("snappy", CompressionCodecName.SNAPPY),
    ("gzip", CompressionCodecName.GZIP),
    ("lzo", CompressionCodecName.LZO),
    ("lz4", CompressionCodecName.LZ4),
    ("brotli", CompressionCodecName.BROTLI),
    ("zstd", CompressionCodecName.ZSTD))

  /**
   * Translates a short codec name into the name of the matching CompressionCodecName
   * constant. The lookup is exact; a name missing from the table makes the map lookup throw.
   */
  def getParquetCompressionCodecName(name: String): String = {
    val codec = shortParquetCompressionCodecNames(name)
    codec.name()
  }
}
Example 35
Source File: SaveIntoDataSourceCommand.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.CreatableRelationProvider


/**
 * Command that hands the result of `query` to a [[CreatableRelationProvider]] for saving.
 *
 * @param query plan producing the data to save
 * @param dataSource provider whose `createRelation` receives the data
 * @param options options passed through to the provider
 * @param mode save mode passed through to the provider
 */
case class SaveIntoDataSourceCommand(
    query: LogicalPlan,
    dataSource: CreatableRelationProvider,
    options: Map[String, String],
    mode: SaveMode) extends RunnableCommand {

  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

  // Delegates the write to the provider; the command itself produces no rows.
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val sqlContext = sparkSession.sqlContext
    val data = Dataset.ofRows(sparkSession, query)
    dataSource.createRelation(sqlContext, mode, options, data)

    Nil
  }

  // The options are passed through SQLConf's redaction before being rendered.
  override def simpleString: String = {
    val safeOptions = SQLConf.get.redactOptions(options)
    s"SaveIntoDataSourceCommand $dataSource, $safeOptions, $mode"
  }
}
Example 36
Source File: Exchange.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType


/**
 * Rule that replaces an exchange by a ReusedExchangeExec pointing at an earlier exchange
 * with the same result. It does nothing when exchange reuse is disabled in the conf.
 */
case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (conf.exchangeReuseEnabled) {
      // Group the exchanges seen so far by schema to avoid O(N*N) sameResult calls.
      val exchangesBySchema = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
      plan.transformUp {
        case exchange: Exchange =>
          // Exchanges with the same result usually share a schema (same column names).
          val candidates =
            exchangesBySchema.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
          candidates.find(seen => exchange.sameResult(seen)) match {
            case Some(existing) =>
              // Keep this exchange's output; the plans above need it to resolve attributes.
              ReusedExchangeExec(exchange.output, existing)
            case None =>
              candidates += exchange
              exchange
          }
      }
    } else {
      plan
    }
  }
}
Example 37
Source File: subquery.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}


/**
 * Rule that points a subquery expression at an earlier subquery plan with the same result.
 * It does nothing when exchange reuse is disabled in the conf.
 */
case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (conf.exchangeReuseEnabled) {
      // Group the subquery plans seen so far by schema to avoid O(N*N) sameResult calls.
      val subqueriesBySchema = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
      plan.transformAllExpressions {
        case sub: ExecSubqueryExpression =>
          val candidates =
            subqueriesBySchema.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
          candidates.find(_.sameResult(sub.plan)) match {
            case Some(existing) =>
              sub.withNewPlan(existing)
            case None =>
              candidates += sub.plan
              sub
          }
      }
    } else {
      plan
    }
  }
}
Example 38
Source File: ArrowUtils.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.arrow

import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

object ArrowUtils {

  val rootAllocator = new RootAllocator(Long.MaxValue)

  // todo: support more types.

  /**
   * Builds the conf entries handed to the Python runner: the session-local time zone (only
   * when `pandasRespectSessionTimeZone` is set) and the grouped-map "assign columns by name"
   * flag, which is always included.
   */
  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
    val sessionTimeZone: Seq[(String, String)] =
      if (conf.pandasRespectSessionTimeZone) {
        List(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
      } else {
        List.empty
      }
    val assignByName =
      SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
        conf.pandasGroupedMapAssignColumnsByName.toString
    (sessionTimeZone :+ assignByName).toMap
  }
}
Example 39
Source File: FileStreamSinkLog.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.net.URI

import org.apache.hadoop.fs.{FileStatus, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf


/**
 * Compactible log of [[SinkFileStatus]] entries whose clean-up delay, expired-log deletion
 * and compaction interval are read from the session conf.
 */
class FileStreamSinkLog(
    metadataLogVersion: Int,
    sparkSession: SparkSession,
    path: String)
  extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) {

  private implicit val formats = Serialization.formats(NoTypeHints)

  protected override val fileCleanupDelayMs = sparkSession.sessionState.conf.fileSinkLogCleanupDelay

  protected override val isDeletingExpiredLog = sparkSession.sessionState.conf.fileSinkLogDeletion

  protected override val defaultCompactInterval =
    sparkSession.sessionState.conf.fileSinkLogCompactInterval

  // Reject a non-positive compaction interval at construction time.
  require(defaultCompactInterval > 0,
    s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $defaultCompactInterval) " +
      "to a positive value.")

  /**
   * Removes every entry whose path also appears in an entry with the delete action; the
   * input is returned unchanged when it contains no delete entries.
   */
  override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = {
    val deletedPaths = logs.collect {
      case entry if entry.action == FileStreamSinkLog.DELETE_ACTION => entry.path
    }.toSet

    if (deletedPaths.nonEmpty) {
      logs.filterNot(entry => deletedPaths.contains(entry.path))
    } else {
      logs
    }
  }
}

object FileStreamSinkLog {
  // Version number of the sink log.
  val VERSION: Int = 1
  // Action value carried by entries that record a deleted file.
  val DELETE_ACTION: String = "delete"
  // Action value carried by entries that record an added file.
  val ADD_ACTION: String = "add"
}
Example 40
Source File: ContinuousShuffleReadRDD.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.continuous.shuffle

import java.util.UUID

import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcAddress
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.NextIterator

/**
 * Partition that lazily creates a shuffle reader and registers it as an RPC endpoint under
 * `endpointName`.
 */
case class ContinuousShuffleReadPartition(
      index: Int,
      endpointName: String,
      queueSize: Int,
      numShuffleWriters: Int,
      epochIntervalMs: Long)
    extends Partition {
  // Initialized only on the executor, and only once even as we call compute() multiple times.
  lazy val (reader: ContinuousShuffleReader, endpoint) = {
    val rpcEnv = SparkEnv.get.rpcEnv
    val shuffleReader = new RPCContinuousShuffleReader(
      queueSize, numShuffleWriters, epochIntervalMs, rpcEnv)
    val endpointRef = rpcEnv.setupEndpoint(endpointName, shuffleReader)

    // Stop the endpoint when the task that first touched this partition completes.
    TaskContext.get().addTaskCompletionListener[Unit] { _ =>
      rpcEnv.stop(endpointRef)
    }
    (shuffleReader, endpointRef)
  }
}


/**
 * RDD with no parent RDDs whose partitions read rows through a per-partition
 * ContinuousShuffleReader.
 *
 * @param endpointNames one endpoint name per partition, indexed by partition number
 */
class ContinuousShuffleReadRDD(
    sc: SparkContext,
    numPartitions: Int,
    queueSize: Int = 1024,
    numShuffleWriters: Int = 1,
    epochIntervalMs: Long = 1000,
    val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}"))
  extends RDD[UnsafeRow](sc, Nil) {

  // One partition per index, each bound to the endpoint name at the same position.
  override protected def getPartitions: Array[Partition] =
    Array.tabulate[Partition](numPartitions) { partIndex =>
      ContinuousShuffleReadPartition(
        partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs)
    }

  // Reads from the partition's reader, which is created on first access.
  override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
    val readPartition = split.asInstanceOf[ContinuousShuffleReadPartition]
    readPartition.reader.read()
  }
}
Example 41
Source File: DataSourceWriteBenchmark.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.benchmark

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Benchmark

/**
 * Shared set-up for write benchmarks: creates a session, and registers one benchmark case
 * per table layout (single numeric column, int + string, partitioned, bucketed) that
 * overwrites a table of the given format from a temporary view of generated ids.
 */
trait DataSourceWriteBenchmark {
  val conf = new SparkConf()
    .setAppName("DataSourceWriteBenchmark")
    .setIfMissing("spark.master", "local[1]")
    .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")

  val spark = SparkSession.builder.config(conf).getOrCreate()

  val tempTable = "temp"
  val numRows = 1024 * 1024 * 15

  // Runs `f`, then drops the given temporary views whether or not `f` succeeded.
  def withTempTable(tableNames: String*)(f: => Unit): Unit = {
    try {
      f
    } finally {
      tableNames.foreach(view => spark.catalog.dropTempView(view))
    }
  }

  // Runs `f`, then drops the given tables whether or not `f` succeeded.
  def withTable(tableNames: String*)(f: => Unit): Unit = {
    try {
      f
    } finally {
      for (name <- tableNames) {
        spark.sql(s"DROP TABLE IF EXISTS $name")
      }
    }
  }

  // Creates a one-column table of `dataType` and adds a case that overwrites it.
  def writeNumeric(table: String, format: String, benchmark: Benchmark, dataType: String): Unit = {
    spark.sql(s"create table $table(id $dataType) using $format")
    val insert = s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS $dataType) AS c1 FROM $tempTable"
    benchmark.addCase(s"Output Single $dataType Column") { _ =>
      spark.sql(insert)
    }
  }

  // Creates an (INT, STRING) table and adds a case that overwrites it.
  def writeIntString(table: String, format: String, benchmark: Benchmark): Unit = {
    spark.sql(s"CREATE TABLE $table(c1 INT, c2 STRING) USING $format")
    val insert = s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " +
      s"c1, CAST(id AS STRING) AS c2 FROM $tempTable"
    benchmark.addCase("Output Int and String Column") { _ =>
      spark.sql(insert)
    }
  }

  // Creates a table partitioned by `p` (id % 2) and adds a case that overwrites it.
  def writePartition(table: String, format: String, benchmark: Benchmark): Unit = {
    spark.sql(s"CREATE TABLE $table(p INT, id INT) USING $format PARTITIONED BY (p)")
    val insert = s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS id," +
      s" CAST(id % 2 AS INT) AS p FROM $tempTable"
    benchmark.addCase("Output Partitions") { _ =>
      spark.sql(insert)
    }
  }

  // Creates a table clustered by `c2` into 2 buckets and adds a case that overwrites it.
  def writeBucket(table: String, format: String, benchmark: Benchmark): Unit = {
    spark.sql(s"CREATE TABLE $table(c1 INT, c2 INT) USING $format CLUSTERED BY (c2) INTO 2 BUCKETS")
    val insert = s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " +
      s"c1, CAST(id AS INT) AS c2 FROM $tempTable"
    benchmark.addCase("Output Buckets") { _ =>
      spark.sql(insert)
    }
  }

  /**
   * Registers all cases for `format` against a temporary view of `numRows` ids and runs
   * them; the view and the benchmark tables are dropped afterwards.
   */
  def runBenchmark(format: String): Unit = {
    val intTable = "tableInt"
    val doubleTable = "tableDouble"
    val intStringTable = "tableIntString"
    val partitionedTable = "tablePartition"
    val bucketedTable = "tableBucket"

    withTempTable(tempTable) {
      spark.range(numRows).createOrReplaceTempView(tempTable)
      withTable(intTable, doubleTable, intStringTable, partitionedTable, bucketedTable) {
        val benchmark = new Benchmark(s"$format writer benchmark", numRows)
        writeNumeric(intTable, format, benchmark, "Int")
        writeNumeric(doubleTable, format, benchmark, "Double")
        writeIntString(intStringTable, format, benchmark)
        writePartition(partitionedTable, format, benchmark)
        writeBucket(bucketedTable, format, benchmark)
        benchmark.run()
      }
    }
  }
}
Example 42
Source File: ParquetFileFormatSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLContext {

  test("read parquet footers in parallel") {
    // Writes two Parquet directories and one JSON directory, then reads the footers of all
    // files in them; exactly two footers are expected to be readable.
    def testReadFooters(ignoreCorruptFiles: Boolean): Unit = {
      withTempDir { dir =>
        val fs = FileSystem.get(spark.sessionState.newHadoopConf())
        val basePath = dir.getCanonicalPath

        val parquetDir1 = new Path(basePath, "first")
        val parquetDir2 = new Path(basePath, "second")
        val jsonDir = new Path(basePath, "third")

        spark.range(1).toDF("a").coalesce(1).write.parquet(parquetDir1.toString)
        spark.range(1, 2).toDF("a").coalesce(1).write.parquet(parquetDir2.toString)
        spark.range(2, 3).toDF("a").coalesce(1).write.json(jsonDir.toString)

        val fileStatuses =
          Seq(parquetDir1, parquetDir2, jsonDir).flatMap(p => fs.listStatus(p).toSeq)

        val footers = ParquetFileFormat.readParquetFootersInParallel(
          spark.sessionState.newHadoopConf(), fileStatuses, ignoreCorruptFiles)

        assert(footers.size == 2)
      }
    }

    testReadFooters(ignoreCorruptFiles = true)

    // Without ignoring corrupt files the read must fail with a footer error as its cause.
    val cause = intercept[SparkException] {
      testReadFooters(ignoreCorruptFiles = false)
    }.getCause
    assert(cause.getMessage().contains("Could not read footer for file"))
  }
}
Example 43
Source File: DataSourceScanExecRedactionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext


/**
 * Checks that plan strings are redacted according to the `spark.redaction.string.regex`
 * SparkConf entry set below and to SQLConf.SQL_STRING_REDACTION_PATTERN.
 */
class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext {

  override protected def sparkConf: SparkConf = super.sparkConf
    .set("spark.redaction.string.regex", "file:/[\\w_]+")

  test("treeString is redacted") {
    withTempDir { dir =>
      val basePath = dir.getCanonicalPath
      spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
      val df = spark.read.parquet(basePath)

      val scanNode = df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
      val rootPath = scanNode.asInstanceOf[FileSourceScanExec].relation.location.rootPaths.head
      assert(rootPath.toString.contains(dir.toURI.getPath.stripSuffix("/")))

      // Every textual rendering of the query that this test inspects.
      def renderings: Seq[String] = Seq(
        df.queryExecution.sparkPlan.treeString(verbose = true),
        df.queryExecution.executedPlan.treeString(verbose = true),
        df.queryExecution.toString,
        df.queryExecution.simpleString)

      // The directory name must not show up in any rendering ...
      renderings.foreach(text => assert(!text.contains(rootPath.getName)))

      // ... and the replacement text must show up in all of them.
      val replacement = "*********"
      renderings.foreach(text => assert(text.contains(replacement)))
    }
  }

  // True when `msg` occurs in at least one of the three string forms of the query execution.
  private def isIncluded(queryExecution: QueryExecution, msg: String): Boolean = {
    def mentions(text: String): Boolean = text.contains(msg)

    mentions(queryExecution.toString) ||
      mentions(queryExecution.simpleString) ||
      mentions(queryExecution.stringWithStats)
  }

  test("explain is redacted using SQLConf") {
    withTempDir { dir =>
      val basePath = dir.getCanonicalPath
      spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
      val df = spark.read.parquet(basePath)
      val replacement = "*********"
      val execution = df.queryExecution

      // Respect SparkConf and replace file:/
      assert(isIncluded(execution, replacement))

      assert(isIncluded(execution, "FileScan"))
      assert(!isIncluded(execution, "file:/"))

      withSQLConf(SQLConf.SQL_STRING_REDACTION_PATTERN.key -> "(?i)FileScan") {
        // Respect SQLConf and replace FileScan
        assert(isIncluded(df.queryExecution, replacement))

        assert(!isIncluded(df.queryExecution, "FileScan"))
        assert(isIncluded(df.queryExecution, "file:/"))
      }
    }
  }

}
Example 44
Source File: DataSourceV2UtilsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources.v2

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.internal.SQLConf

class DataSourceV2UtilsSuite extends SparkFunSuite {

  private val keyPrefix = new DataSourceV2WithSessionConfig().keyPrefix

  test("method withSessionConfig() should propagate session configs correctly") {
    // Only match configs with keys start with "spark.datasource.${keyPrefix}".
    val conf = new SQLConf
    Seq(
      s"spark.datasource.$keyPrefix.foo.bar" -> "false",
      s"spark.datasource.$keyPrefix.whateverConfigName" -> "123",
      s"spark.sql.$keyPrefix.config.name" -> "false",
      "spark.datasource.another.config.name" -> "123",
      s"spark.datasource.$keyPrefix." -> "123"
    ).foreach { case (key, value) => conf.setConfString(key, value) }

    val source = classOf[DataSourceV2WithSessionConfig].newInstance()
    val extracted =
      DataSourceV2Utils.extractSessionConfigs(source.asInstanceOf[DataSourceV2], conf)
    val extractedKeys = extracted.keySet

    // Only the first two entries above qualify, and they are returned without the prefix.
    assert(extracted.size == 2)
    assert(!extractedKeys.exists(_.startsWith("spark.datasource")))
    assert(!extractedKeys.exists(_.startsWith("not.exist.prefix")))
    assert(extractedKeys.contains("foo.bar"))
    assert(extractedKeys.contains("whateverConfigName"))
  }
}

// Test data source that reports a fixed session config key prefix.
class DataSourceV2WithSessionConfig extends SimpleDataSourceV2 with SessionConfigSupport {

  override def keyPrefix: String = {
    "userDefinedDataSource"
  }
}
Example 45
Source File: DataSourceTest.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

private[sql] abstract class DataSourceTest extends QueryTest {

  /**
   * Registers a test, named after the SQL text, that runs `sqlString` and checks its result.
   *
   * @param sqlString the query to run; also used as the test name
   * @param expectedAnswer rows the query is expected to return
   * @param enableRegex value set for SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME while the
   *                    query runs
   */
  protected def sqlTest(
      sqlString: String,
      expectedAnswer: Seq[Row],
      enableRegex: Boolean = false): Unit = {
    test(sqlString) {
      withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> enableRegex.toString) {
        checkAnswer(spark.sql(sqlString), expectedAnswer)
      }
    }
  }

}

/** Relation provider that builds a [[SimpleDDLScan]] from the supplied options. */
class DDLScanSource extends RelationProvider {

  /**
   * Reads the "from", "TO" and "Table" options and creates the relation.
   * NOTE(review): the option keys use mixed casing; presumably the incoming map is
   * case-insensitive — confirm against the caller.
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val lowerBound = parameters("from").toInt
    val upperBound = parameters("TO").toInt
    val tableName = parameters("Table")
    SimpleDDLScan(lowerBound, upperBound, tableName)(sqlContext.sparkSession)
  }
}

/**
 * Table-scan relation exposing a fixed schema that covers many column types. The first column
 * carries a comment that embeds `table`.
 */
case class SimpleDDLScan(
    from: Int,
    to: Int,
    table: String)(@transient val sparkSession: SparkSession)
  extends BaseRelation with TableScan {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  override def schema: StructType = {
    // Shorthand for the non-nullable columns.
    def required(name: String, dataType: DataType): StructField =
      StructField(name, dataType, nullable = false)

    val nestedStruct = StructType(Seq(
      StructField("f1", StringType),
      StructField("f2", IntegerType)))

    StructType(Seq(
      required("intType", IntegerType).withComment(s"test comment $table"),
      required("stringType", StringType),
      required("dateType", DateType),
      required("timestampType", TimestampType),
      required("doubleType", DoubleType),
      required("bigintType", LongType),
      required("tinyintType", ByteType),
      required("decimalType", DecimalType.USER_DEFAULT),
      required("fixedDecimalType", DecimalType(5, 1)),
      required("binaryType", BinaryType),
      required("booleanType", BooleanType),
      required("smallIntType", ShortType),
      required("floatType", FloatType),
      StructField("mapType", MapType(StringType, StringType)),
      StructField("arrayType", ArrayType(StringType)),
      StructField("structType", nestedStruct)
    ))
  }

  override def needConversion: Boolean = false

  override def buildScan(): RDD[Row] = {
    val internalRows = sparkSession.sparkContext.parallelize(from to to).map { i =>
      InternalRow(UTF8String.fromString(s"people$i"), i * 2)
    }
    // Type-erasure hack: hand the RDD[InternalRow] back typed as RDD[Row].
    internalRows.asInstanceOf[RDD[Row]]
  }
}
Example 46
Source File: SharedSparkSession.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.test

import scala.concurrent.duration._

import org.scalatest.{BeforeAndAfterEach, Suite}
import org.scalatest.concurrent.Eventually

import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.internal.SQLConf


  /**
   * Suite teardown. Runs the parent's `afterAll`, then, if a session exists, resets its catalog
   * and stops it, and finally clears the active and default sessions. The nested
   * `try`/`finally` blocks ensure each later cleanup step still runs when an earlier one throws.
   */
  protected override def afterAll(): Unit = {
    try {
      super.afterAll()
    } finally {
      try {
        if (_spark != null) {
          try {
            _spark.sessionState.catalog.reset()
          } finally {
            // Stop the session and drop the reference even if the catalog reset failed.
            _spark.stop()
            _spark = null
          }
        }
      } finally {
        SparkSession.clearActiveSession()
        SparkSession.clearDefaultSession()
      }
    }
  }

  /** Runs the parent's `beforeEach`, then calls `DebugFilesystem.clearOpenStreams()`. */
  protected override def beforeEach(): Unit = {
    super.beforeEach()
    DebugFilesystem.clearOpenStreams()
  }

  /**
   * Runs the parent's `afterEach`, clears the session's cache manager, and then repeatedly
   * asserts (timeout 10 seconds, interval 2 seconds) that `DebugFilesystem` reports no open
   * streams.
   */
  protected override def afterEach(): Unit = {
    super.afterEach()
    // Clear all persistent datasets after each test
    spark.sharedState.cacheManager.clearCache()
    // files can be closed from other threads, so wait a bit
    // normally this doesn't take more than 1s
    eventually(timeout(10.seconds), interval(2.seconds)) {
      DebugFilesystem.assertNoOpenStreams()
    }
  }
} 
Example 47
Source File: TestSQLContext.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.test

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf}


  /** Configuration overrides applied to test sessions. */
  val overrideConfs: Map[String, String] = {
    // Fewer shuffle partitions to speed up testing.
    val shufflePartitions = SQLConf.SHUFFLE_PARTITIONS.key -> "5"
    Map(shufflePartitions)
  }
}

/** Session state builder for tests; takes its config overrides from `TestSQLContext`. */
private[sql] class TestSQLSessionStateBuilder(
    session: SparkSession,
    state: Option[SessionState])
  extends SessionStateBuilder(session, state) with WithTestConf {

  /** Returns `TestSQLContext.overrideConfs`. */
  override def overrideConfs: Map[String, String] = {
    TestSQLContext.overrideConfs
  }

  /** Factory producing another builder of this same class. */
  override def newBuilder: NewBuilder = { (newSession, newState) =>
    new TestSQLSessionStateBuilder(newSession, newState)
  }
}
Example 48
Source File: CodeGeneration.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package com.intel.sparkColumnarPlugin.expression

import org.apache.arrow.vector.Float4Vector
import org.apache.arrow.vector.IntVector
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.pojo.ArrowType

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils

/** Helpers for choosing Arrow result types and cast function names. */
object CodeGeneration {
  // NOTE(review): this is a `val` on an object, so the session time zone is read once, when
  // the object is first initialized; later changes to the session time zone are not picked up.
  val timeZoneId = SQLConf.get.sessionLocalTimeZone

  /**
   * Picks the result type for a binary operation: `left` when both types are equal or when
   * `left` has the higher precise level, otherwise `right`.
   *
   * @throws UnsupportedOperationException if the types differ and either one has no precise
   *                                       level (see [[getPreciseLevel]])
   */
  def getResultType(left: ArrowType, right: ArrowType): ArrowType = {
    if (left.equals(right)) {
      left
    } else if (getPreciseLevel(left) > getPreciseLevel(right)) {
      left
    } else {
      right
    }
  }

  /** Converts a Spark `DataType` to its Arrow type using [[timeZoneId]]. */
  def getResultType(dataType: DataType): ArrowType =
    ArrowUtils.toArrowType(dataType, timeZoneId)

  /** Default result type: a signed 32-bit integer. */
  def getResultType(): ArrowType = {
    new ArrowType.Int(32, true)
  }

  /**
   * Returns the precise level of an Arrow type: 4 for integers, 8 for floating point.
   *
   * @throws UnsupportedOperationException for any other Arrow type
   */
  def getPreciseLevel(dataType: ArrowType): Int = {
    dataType match {
      case _: ArrowType.Int =>
        4
      case _: ArrowType.FloatingPoint =>
        8
      case _ =>
        throw new UnsupportedOperationException(s"Unable to get precise level of $dataType ${dataType.getClass}.")
    }
  }

  /**
   * Returns the cast function name for a floating-point Arrow type: "castFLOAT" followed by
   * four times the flatbuffer id of the type's precision.
   *
   * @throws UnsupportedOperationException for non floating-point types
   */
  def getCastFuncName(dataType: ArrowType): String = {
    dataType match {
      case floatingPoint: ArrowType.FloatingPoint =>
        s"castFLOAT${4 * floatingPoint.getPrecision().getFlatbufID()}"
      case _ =>
        throw new UnsupportedOperationException(s"getCastFuncName(${dataType}) is not supported.")
    }
  }
}
Example 49
Source File: FilterHelper.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.utils

import org.apache.hadoop.conf.Configuration
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
import org.apache.parquet.hadoop.ParquetInputFormat

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.parquet.ParquetFiltersWrapper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

/** Helpers for converting data source filters to a Parquet predicate and registering it. */
object FilterHelper {

  /** Same as the `SQLConf` overload, using the session's own configuration. */
  def tryToPushFilters(
      sparkSession: SparkSession,
      requiredSchema: StructType,
      filters: Seq[Filter]): Option[FilterPredicate] = {
    val sessionConf = sparkSession.sessionState.conf
    tryToPushFilters(sessionConf, requiredSchema, filters)
  }

  /**
   * Combines every convertible filter into a single `and` predicate.
   *
   * @return `None` when push-down is disabled in `conf` or when no filter could be converted
   */
  def tryToPushFilters(
      conf: SQLConf,
      requiredSchema: StructType,
      filters: Seq[Filter]): Option[FilterPredicate] = {
    if (!conf.parquetFilterPushDown) {
      None
    } else {
      // Not every filter can be converted (`createFilter` returns an `Option`), hence the
      // `flatMap`, which keeps only the successfully converted predicates.
      val converted = filters.flatMap { filter =>
        ParquetFiltersWrapper.createFilter(conf, requiredSchema, filter)
      }
      converted.reduceOption(FilterApi.and)
    }
  }

  /** Registers the predicate on `configuration` when one is present; otherwise does nothing. */
  def setFilterIfExist(configuration: Configuration, pushed: Option[FilterPredicate]): Unit = {
    pushed.foreach { predicate =>
      ParquetInputFormat.setFilterPredicate(configuration, predicate)
    }
  }
}
Example 50
Source File: ParquetFiltersWrapper.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.filter2.predicate.FilterPredicate

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType


/** Thin wrapper that builds a `ParquetFilters` from a `SQLConf` and converts one filter. */
object ParquetFiltersWrapper {

  /**
   * Converts `predicate` to a Parquet filter predicate for the given Spark `schema`.
   *
   * @return the result of `ParquetFilters.createFilter` for this predicate
   */
  def createFilter(
      conf: SQLConf, schema: StructType,
      predicate: sources.Filter): Option[FilterPredicate] = {
    val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
    val parquetFilters = new ParquetFilters(
      parquetSchema,
      conf.parquetFilterPushDownDate,
      conf.parquetFilterPushDownTimestamp,
      conf.parquetFilterPushDownDecimal,
      conf.parquetFilterPushDownStringStartWith,
      conf.parquetFilterPushDownInFilterThreshold,
      conf.caseSensitiveAnalysis)
    parquetFilters.createFilter(predicate)
  }
}
Example 51
Source File: OapQuerySuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import java.util.{Locale, TimeZone}

import org.scalatest.{BeforeAndAfter, Ignore}

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.internal.SQLConf

// Ignore because in separate package will encounter problem with shaded spark source.
@Ignore
class OapQuerySuite extends HiveComparisonTest with BeforeAndAfter  {
  // These "original" values are lazy so that constructing the suite has no side effects. They
  // are forced at the start of beforeAll(), before any default is changed; otherwise their
  // first evaluation would happen in afterAll() and capture the already-modified values.
  private lazy val originalTimeZone = TimeZone.getDefault
  private lazy val originalLocale = Locale.getDefault
  import org.apache.spark.sql.hive.test.TestHive._

  // Note: invoke TestHive will create a SparkContext which can't be configured by us.
  // So be careful this may affect current using SparkContext and cause strange problem.
  private lazy val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled

  /** Records the current defaults, then switches to the fixed settings the tests rely on. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    // Force the lazy vals now so that afterAll() restores the real pre-test values.
    originalTimeZone
    originalLocale
    originalCrossJoinEnabled
    TestHive.setCacheTables(true)
    // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
    TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
    // Add Locale setting
    Locale.setDefault(Locale.US)
    // Ensures that cross joins are enabled so that we can test them
    TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
    TestHive.setConf(HiveUtils.CONVERT_METASTORE_PARQUET, true)
  }

  /** Restores the defaults recorded in beforeAll() and always runs the parent's teardown. */
  override def afterAll(): Unit = {
    try {
      TestHive.setCacheTables(false)
      TimeZone.setDefault(originalTimeZone)
      Locale.setDefault(originalLocale)
      sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2")
      TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
    } finally {
      super.afterAll()
    }
  }

  /** Asserts that `body` fails with an AnalysisException whose message mentions "exists". */
  private def assertDupIndex(body: => Unit): Unit = {
    val e = intercept[AnalysisException] { body }
    assert(e.getMessage.toLowerCase.contains("exists"))
  }

  test("create hive table in parquet format") {
    try {
      sql("create table p_table (key int, val string) stored as parquet")
      sql("insert overwrite table p_table select * from src")
      sql("create oindex if not exists p_index on p_table(key)")
      assert(sql("select val from p_table where key = 238")
        .collect().head.getString(0) == "val_238")
    } finally {
      sql("drop oindex p_index on p_table")
      sql("drop table p_table")
    }
  }

  test("create duplicate hive table in parquet format") {
    try {
      sql("create table p_table1 (key int, val string) stored as parquet")
      sql("insert overwrite table p_table1 select * from src")
      sql("create oindex p_index on p_table1(key)")
      assertDupIndex { sql("create oindex p_index on p_table1(key)") }
    } finally {
      sql("drop oindex p_index on p_table1")
      // Drop the table as well, as the previous test does, so it does not outlive this test.
      sql("drop table p_table1")
    }
  }
}
Example 52
Source File: FilterHelperSuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.utils

import org.apache.hadoop.conf.Configuration
import org.apache.parquet.hadoop.ParquetInputFormat

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

/** Tests for `FilterHelper.tryToPushFilters` and `FilterHelper.setFilterIfExist`. */
class FilterHelperSuite extends SparkFunSuite {
  // A private SQLConf instance: the tests toggle PARQUET_FILTER_PUSHDOWN_ENABLED on it, and
  // doing that on the shared `SQLConf.get` object would leave the changed value behind.
  val conf: SQLConf = new SQLConf

  // Fixtures shared by both tests.
  private val requiredSchema = new StructType()
    .add(StructField("a", IntegerType))
    .add(StructField("b", StringType))
  private val filters = Seq(GreaterThan("a", 1), EqualTo("b", "2"))

  test("Pushed And Set") {
    val expected = """and(gt(a, 1), eq(b, Binary{"2"}))"""
    conf.setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, true)
    val pushed = FilterHelper.tryToPushFilters(conf, requiredSchema, filters)
    assert(pushed.isDefined)
    assert(pushed.get.toString == expected)
    val config = new Configuration()
    FilterHelper.setFilterIfExist(config, pushed)
    val humanReadable = config.get(ParquetInputFormat.FILTER_PREDICATE + ".human.readable")
    assert(humanReadable.nonEmpty)
    assert(humanReadable == expected)
  }

  test("Not Pushed") {
    conf.setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, false)
    val pushed = FilterHelper.tryToPushFilters(conf, requiredSchema, filters)
    assert(pushed.isEmpty)
    val config = new Configuration()
    FilterHelper.setFilterIfExist(config, pushed)
    // With nothing pushed, neither predicate key may be set on the configuration.
    assert(config.get(ParquetInputFormat.FILTER_PREDICATE) == null)
    assert(config.get(ParquetInputFormat.FILTER_PREDICATE + ".human.readable") == null)
  }
}
Example 53
Source File: OapSharedSQLContext.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.test

import scala.concurrent.duration._

import org.scalatest.{BeforeAndAfterEach, Suite}
import org.scalatest.concurrent.Eventually

import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.oap.OapRuntime


/**
 * Test mix-in combining `SQLTestUtils` with `OapSharedSparkSession`; it declares no members
 * of its own.
 */
trait OapSharedSQLContext extends SQLTestUtils with OapSharedSparkSession


  /**
   * Suite teardown. Runs the parent's `afterAll`, then, if a session exists, resets its
   * catalog, stops the OAP runtime and stops the session, and finally clears the active and
   * default sessions. Every step sits in its own `try`/`finally`, so a failure in one step
   * (including `OapRuntime.stop()`) does not prevent the remaining cleanup from running.
   */
  protected override def afterAll(): Unit = {
    try {
      super.afterAll()
    } finally {
      try {
        if (_spark != null) {
          try {
            _spark.sessionState.catalog.reset()
          } finally {
            try {
              OapRuntime.stop()
            } finally {
              // Must run even if stopping the OAP runtime throws, or the session would leak.
              _spark.stop()
              _spark = null
            }
          }
        }
      } finally {
        SparkSession.clearActiveSession()
        SparkSession.clearDefaultSession()
      }
    }
  }

  /** Runs the parent's `beforeEach`, then calls `DebugFilesystem.clearOpenStreams()`. */
  protected override def beforeEach(): Unit = {
    super.beforeEach()
    DebugFilesystem.clearOpenStreams()
  }

  /**
   * Runs the parent's `afterEach`, clears the session's cache manager, and then repeatedly
   * asserts (timeout 10 seconds, interval 2 seconds) that `DebugFilesystem` reports no open
   * streams.
   */
  protected override def afterEach(): Unit = {
    super.afterEach()
    // Clear all persistent datasets after each test
    spark.sharedState.cacheManager.clearCache()
    // files can be closed from other threads, so wait a bit
    // normally this doesn't take more than 1s
    eventually(timeout(10.seconds), interval(2.seconds)) {
      DebugFilesystem.assertNoOpenStreams()
    }
  }
} 
Example 54
Source File: ArrowSQLConf.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package com.intel.oap.spark.sql.execution.datasources.v2.arrow

import org.apache.spark.sql.internal.SQLConf

/** Arrow-specific SQL configuration entries, plus an implicit view from `SQLConf`. */
object ArrowSQLConf {

  /** Boolean entry "spark.sql.arrow.filterPushdown"; defaults to true. */
  val ARROW_FILTER_PUSHDOWN_ENABLED =
    SQLConf
      .buildConf("spark.sql.arrow.filterPushdown")
      .doc("Enables Arrow filter push-down optimization when set to true.")
      .booleanConf
      .createWithDefault(true)

  /** Wraps a `SQLConf` so that the Arrow accessors can be called on it directly. */
  implicit def fromSQLConf(c: SQLConf): ArrowSQLConf = new ArrowSQLConf(c)
}

/** Read-only view over a `SQLConf` exposing the Arrow-specific settings. */
class ArrowSQLConf(c: SQLConf) {

  /** Current value of `ArrowSQLConf.ARROW_FILTER_PUSHDOWN_ENABLED` in the wrapped conf. */
  def arrowFilterPushDown: Boolean = {
    val entry = ArrowSQLConf.ARROW_FILTER_PUSHDOWN_ENABLED
    c.getConf(entry)
  }
}
Example 55
Source File: SparkPlanner.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.internal.SQLConf

/** Planner that combines user-supplied strategies with the built-in ones. */
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val extraStrategies: Seq[Strategy])
  extends SparkStrategies {

  /** The number of shuffle partitions configured in `conf`. */
  def numPartitions: Int = conf.numShufflePartitions

  /** `extraStrategies` first, followed by the built-in strategies in a fixed order. */
  def strategies: Seq[Strategy] = {
    val builtIn: Seq[Strategy] = Seq(
      FileSourceStrategy,
      DataSourceStrategy,
      DDLStrategy,
      SpecialLimits,
      Aggregation,
      JoinSelection,
      InMemoryScans,
      BasicOperators)
    extraStrategies ++ builtIn
  }

  /** Pairs every `PlanLater` node found in `plan` with the logical plan it wraps. */
  override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = {
    plan.collect {
      case later @ PlanLater(logical) => (later, logical)
    }
  }

  /** Currently returns the candidate plans unchanged. */
  override protected def prunePlans(plans: Iterator[SparkPlan]): Iterator[SparkPlan] = {
    // TODO: We will need to prune bad plans when we improve plan space exploration
    //       to prevent combinatorial explosion.
    plans
  }

  /**
   * Builds a scan with an optional filter and, when needed, a projection on top.
   *
   * @param projectList            expressions to output
   * @param filterPredicates       predicates to apply
   * @param prunePushedDownFilters selects which predicates still have to be evaluated here
   * @param scanBuilder            creates the scan for a given list of attributes
   */
  def pruneFilterProject(
      projectList: Seq[NamedExpression],
      filterPredicates: Seq[Expression],
      prunePushedDownFilters: Seq[Expression] => Seq[Expression],
      scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {

    val projectSet = AttributeSet(projectList.flatMap(_.references))
    val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
    val filterCondition: Option[Expression] =
      prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And)

    // Wraps the scan in a filter when there is a remaining condition.
    def withFilter(scan: SparkPlan): SparkPlan = filterCondition match {
      case Some(remaining) => FilterExec(remaining, scan)
      case None => scan
    }

    // Right now we still use a projection even if the only evaluation is applying an alias
    // to a column.  Since this is a no-op, it could be avoided. However, using this
    // optimization with the current implementation would change the output schema.
    // TODO: Decouple final output schema from expression evaluation so this copy can be
    // avoided safely.

    // True when column pruning alone yields the right projection and the projected columns
    // are enough to evaluate all filter conditions.
    val pruningSuffices =
      AttributeSet(projectList.map(_.toAttribute)) == projectSet && filterSet.subsetOf(projectSet)

    if (pruningSuffices) {
      // Scan followed by a filter, with no extra project.
      withFilter(scanBuilder(projectList.asInstanceOf[Seq[Attribute]]))
    } else {
      val widened = (projectSet ++ filterSet).toSeq
      ProjectExec(projectList, withFilter(scanBuilder(widened)))
    }
  }
}
Example 56
Source File: CartesianProductExec.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter


/**
 * Cartesian product of two RDDs of `UnsafeRow`. For each partition pair, all rows of the
 * right-hand partition are first inserted into an `UnsafeExternalSorter` (used as a buffer,
 * without sorting), and that buffer is then re-iterated once for every left-hand row.
 *
 * @param numFieldsOfRight number of fields used to construct the `UnsafeRow` that reads the
 *                         buffered right-hand records
 */
class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
  extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {

  override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
    // We will not sort the rows, so prefixComparator and recordComparator are null.
    val sorter = UnsafeExternalSorter.create(
      context.taskMemoryManager(),
      SparkEnv.get.blockManager,
      SparkEnv.get.serializerManager,
      context,
      null,
      null,
      1024,
      SparkEnv.get.memoryManager.pageSizeBytes,
      SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
        UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
      false)

    // Buffer every row of the right-hand partition before producing any output.
    val partition = split.asInstanceOf[CartesianPartition]
    for (y <- rdd2.iterator(partition.s2, context)) {
      sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false)
    }

    // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
    def createIter(): Iterator[UnsafeRow] = {
      val iter = sorter.getIterator
      // A single UnsafeRow instance is re-pointed at each record, so a row returned by next()
      // is overwritten by the following call to next().
      val unsafeRow = new UnsafeRow(numFieldsOfRight)
      new Iterator[UnsafeRow] {
        override def hasNext: Boolean = {
          iter.hasNext
        }
        override def next(): UnsafeRow = {
          iter.loadNext()
          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
          unsafeRow
        }
      }
    }

    // createIter() is invoked once per left-hand row, giving a fresh pass over the buffer.
    val resultIter =
      for (x <- rdd1.iterator(partition.s1, context);
           y <- createIter()) yield (x, y)
    // NOTE(review): presumably the sorter's resources are released once the result iterator
    // has been fully consumed — confirm against CompletionIterator.
    CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
      resultIter, sorter.cleanupResources())
  }
}


/**
 * Physical operator producing the cartesian product of `left` and `right`, optionally keeping
 * only the row pairs that satisfy `condition`.
 */
case class CartesianProductExec(
    left: SparkPlan,
    right: SparkPlan,
    condition: Option[Expression]) extends BinaryExecNode {

  /** Left output attributes followed by right output attributes. */
  override def output: Seq[Attribute] = left.output ++ right.output

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")

    val leftRows = left.execute().asInstanceOf[RDD[UnsafeRow]]
    val rightRows = right.execute().asInstanceOf[RDD[UnsafeRow]]
    val cartesian = new UnsafeCartesianRDD(leftRows, rightRows, right.output.size)

    cartesian.mapPartitionsWithIndexInternal { (partitionIndex, pairs) =>
      val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
      val matching = condition match {
        case Some(predicate) =>
          // Bind the predicate against the combined output and evaluate it on each pair.
          val boundCondition = newPredicate(predicate, left.output ++ right.output)
          boundCondition.initialize(partitionIndex)
          val joined = new JoinedRow
          pairs.filter { pair =>
            boundCondition.eval(joined(pair._1, pair._2))
          }
        case None =>
          pairs
      }
      matching.map { pair =>
        numOutputRows += 1
        joiner.join(pair._1, pair._2)
      }
    }
  }
}
Example 57
Source File: ParquetOptions.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.hadoop.metadata.CompressionCodecName

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf


  /**
   * Whether schemas should be merged: the `MERGE_SCHEMA` option when it is present, otherwise
   * the value of `sqlConf.isParquetSchemaMergingEnabled`.
   */
  val mergeSchema: Boolean = {
    val requested = parameters.get(MERGE_SCHEMA).map(_.toBoolean)
    requested.getOrElse(sqlConf.isParquetSchemaMergingEnabled)
  }
}


/** Option names and codec short names for the Parquet data source. */
object ParquetOptions {
  /** Name of the option that controls schema merging. */
  val MERGE_SCHEMA = "mergeSchema"

  // The parquet compression short names
  private val shortParquetCompressionCodecNames = {
    import CompressionCodecName._
    Map(
      "none" -> UNCOMPRESSED,
      "uncompressed" -> UNCOMPRESSED,
      "snappy" -> SNAPPY,
      "gzip" -> GZIP,
      "lzo" -> LZO)
  }
}
Example 58
Source File: Exchange.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType


/**
 * Rule that replaces an exchange with a `ReusedExchangeExec` when an earlier exchange in the
 * plan produces the same result. Does nothing when exchange reuse is disabled in `conf`.
 */
case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (conf.exchangeReuseEnabled) {
      reuseExchanges(plan)
    } else {
      plan
    }
  }

  /** Bottom-up pass that substitutes duplicate exchanges. */
  private def reuseExchanges(plan: SparkPlan): SparkPlan = {
    // Exchanges seen so far, grouped by schema, to avoid O(N*N) sameResult calls.
    val exchangesBySchema = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
    plan.transformUp {
      case exchange: Exchange =>
        // Exchanges with the same result usually also share a schema (same column names).
        val candidates =
          exchangesBySchema.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
        candidates.find(seen => exchange.sameResult(seen)) match {
          case Some(reusable) =>
            // Keep the output of this exchange, the following plans require that to resolve
            // attributes.
            ReusedExchangeExec(exchange.output, reusable)
          case None =>
            candidates += exchange
            exchange
        }
    }
  }
}
Example 59
Source File: FileStreamSinkLog.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import org.apache.hadoop.fs.{FileStatus, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization
import org.json4s.jackson.Serialization.{read, write}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf


/**
 * Compactible metadata log of `SinkFileStatus` entries. Cleanup delay, expired-log deletion
 * and the default compaction interval are all read from the session configuration.
 */
class FileStreamSinkLog(
    metadataLogVersion: String,
    sparkSession: SparkSession,
    path: String)
  extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) {

  private implicit val formats = Serialization.formats(NoTypeHints)

  protected override val fileCleanupDelayMs = sparkSession.sessionState.conf.fileSinkLogCleanupDelay

  protected override val isDeletingExpiredLog = sparkSession.sessionState.conf.fileSinkLogDeletion

  protected override val defaultCompactInterval =
    sparkSession.sessionState.conf.fileSinkLogCompactInterval

  // Checked at construction time, after `defaultCompactInterval` has been initialized.
  require(defaultCompactInterval > 0,
    s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $defaultCompactInterval) " +
      "to a positive value.")

  /**
   * Removes every entry whose path also appears in a delete-action entry; returns `logs`
   * unchanged when there is no delete action at all.
   */
  override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = {
    val deletedPaths = logs.collect {
      case entry if entry.action == FileStreamSinkLog.DELETE_ACTION => entry.path
    }.toSet
    if (deletedPaths.nonEmpty) {
      logs.filterNot(entry => deletedPaths.contains(entry.path))
    } else {
      logs
    }
  }
}

/** String constants used with the file stream sink log. */
object FileStreamSinkLog {
  // NOTE(review): presumably the metadata log format version — confirm against callers.
  val VERSION = "v1"
  // Action value that `compactLogs` treats as marking a path as deleted.
  val DELETE_ACTION = "delete"
  // NOTE(review): presumably the action value for added files — not referenced in the code
  // visible here.
  val ADD_ACTION = "add"
} 
Example 60
Source File: SparkOptimizer.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.internal.SQLConf

/** Optimizer that appends additional batches after the ones inherited from `Optimizer`. */
class SparkOptimizer(
    catalog: SessionCatalog,
    conf: SQLConf,
    experimentalMethods: ExperimentalMethods)
  extends Optimizer(catalog, conf) {

  /** The inherited batches followed by the four batches listed here, in order. */
  override def batches: Seq[Batch] = {
    val additional = Seq(
      Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)),
      Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate),
      Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions),
      Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*))
    super.batches ++ additional
  }
}
Example 61
Source File: CarbonSqlAstBuilder.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.parser.ParserUtils.string
import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{AddTableColumnsContext, CreateHiveTableContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkSqlAstBuilder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.parser.{CarbonHelperSqlAstBuilder, CarbonSpark2SqlParser, CarbonSparkSqlParserUtil}

/**
 * AST builder that routes CREATE TABLE statements using a Carbon file format to the Carbon
 * helper and leaves every other statement to the parent builder.
 */
class CarbonSqlAstBuilder(conf: SQLConf, parser: CarbonSpark2SqlParser, sparkSession: SparkSession)
  extends SparkSqlAstBuilder(conf) with SqlAstBuilderHelper {

  val helper = new CarbonHelperSqlAstBuilder(conf, parser, sparkSession)

  override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = {
    val fileStorage = CarbonSparkSqlParserUtil.getFileStorage(ctx.createFileFormat(0))

    // File format spellings (compared case-insensitively) that select the Carbon path.
    val carbonFormats = Seq(
      "'carbondata'",
      "carbondata",
      "'carbonfile'",
      "'org.apache.carbondata.format'")
    val isCarbonFormat = carbonFormats.exists(format => fileStorage.equalsIgnoreCase(format))

    if (!isCarbonFormat) {
      super.visitCreateHiveTable(ctx)
    } else {
      val createTableTuple = (ctx.createTableHeader, ctx.skewSpec(0),
        ctx.bucketSpec(0), ctx.partitionColumns, ctx.columns, ctx.tablePropertyList(0),
        ctx.locationSpec(0), Option(ctx.STRING(0)).map(string), ctx.AS, ctx.query, fileStorage)
      helper.createCarbonTable(createTableTuple)
    }
  }

  /** Delegates to the overload that also takes the parser. */
  override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = {
    visitAddTableColumns(parser, ctx)
  }
}
Example 62
Source File: CarbonAnalyzer.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.CarbonReflectionUtils

/**
 * Analyzer that delegates to the wrapped `analyzer` and then applies the MV rule when that
 * rule could be created.
 */
class CarbonAnalyzer(catalog: SessionCatalog,
    conf: SQLConf,
    sparkSession: SparkSession,
    analyzer: Analyzer) extends Analyzer(catalog, conf) {

  // The MV analyzer rule, created reflectively; null when creating it throws an Exception.
  val mvPlan = try {
    CarbonReflectionUtils.createObject(
      "org.apache.carbondata.mv.extension.MVAnalyzerRule",
      sparkSession)._1.asInstanceOf[Rule[LogicalPlan]]
  } catch {
    case _: Exception =>
      null
  }

  /** Runs the wrapped analyzer, then the MV rule if one is available. */
  override def execute(plan: LogicalPlan): LogicalPlan = {
    val analyzed = analyzer.execute(plan)
    Option(mvPlan) match {
      case Some(rule) => rule.apply(analyzed)
      case None => analyzed
    }
  }
}
Example 63
Source File: CarbonExtensionSqlParser.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.parser

import org.apache.spark.sql.{CarbonEnv, CarbonUtils, SparkSession}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CarbonException

import org.apache.carbondata.common.exceptions.sql.MalformedCarbonCommandException
import org.apache.carbondata.spark.util.CarbonScalaUtil


/**
 * SQL parser that tries the Carbon parser first and falls back to `initialParser` when the
 * Carbon parser fails with anything other than a `MalformedCarbonCommandException`.
 */
class CarbonExtensionSqlParser(
    conf: SQLConf,
    sparkSession: SparkSession,
    initialParser: ParserInterface
) extends SparkSqlParser(conf) {

  val parser = new CarbonExtensionSpark2SqlParser

  override def parsePlan(sqlText: String): LogicalPlan = {
    parser.synchronized {
      CarbonEnv.getInstance(sparkSession)
    }
    CarbonUtils.updateSessionInfoToCurrentThread(sparkSession)
    try {
      parser.parse(sqlText)
    } catch {
      case malformed: MalformedCarbonCommandException =>
        throw malformed
      case carbonFailure: Throwable =>
        parseWithInitialParser(sqlText, carbonFailure)
    }
  }

  /**
   * Fallback path: parses with `initialParser`. If that fails too, reports the messages of
   * both parsers through `CarbonException.analysisException`.
   */
  private def parseWithInitialParser(sqlText: String, carbonFailure: Throwable): LogicalPlan = {
    try {
      val fallbackPlan = initialParser.parsePlan(sqlText)
      CarbonScalaUtil.cleanParserThreadLocals
      fallbackPlan
    } catch {
      case malformed: MalformedCarbonCommandException =>
        throw malformed
      case fallbackFailure: Throwable =>
        fallbackFailure.printStackTrace(System.err)
        CarbonScalaUtil.cleanParserThreadLocals
        CarbonException.analysisException(
          s"""== Parser1: ${parser.getClass.getName} ==
             |${carbonFailure.getMessage}
             |== Parser2: ${initialParser.getClass.getName} ==
             |${fallbackFailure.getMessage}
           """.stripMargin.trim)
    }
  }
}
Example 64
Source File: PreprocessTableDelete.scala    From delta   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.commands.DeleteCommand

import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{DeltaDelete, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf


/**
 * Rule that replaces every resolved [[DeltaDelete]] node with a [[DeleteCommand]].
 * A delete condition containing a subquery is rejected with an error.
 */
case class PreprocessTableDelete(conf: SQLConf) extends Rule[LogicalPlan] {

  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
    case delete: DeltaDelete if delete.resolved =>
      // Reject the statement as soon as the predicate is found to contain a subquery.
      for (predicate <- delete.condition if SubqueryExpression.hasSubquery(predicate)) {
        throw DeltaErrors.subqueryNotSupportedException("DELETE", predicate)
      }
      toCommand(delete)
  }

  /**
   * Builds the [[DeleteCommand]] for `d`. The child plan, with subquery aliases eliminated,
   * must match `DeltaFullTable`; any other plan raises the "not a Delta source" error.
   */
  def toCommand(d: DeltaDelete): DeleteCommand = {
    val target = EliminateSubqueryAliases(d.child)
    target match {
      case DeltaFullTable(tahoeFileIndex) =>
        DeleteCommand(tahoeFileIndex, d.child, d.condition)
      case _ =>
        throw DeltaErrors.notADeltaSourceException("DELETE", Some(target))
    }
  }
}
Example 65
Source File: PreprocessTableUpdate.scala    From delta   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.commands.UpdateCommand

import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf


/**
 * Rule that replaces every resolved [[DeltaUpdateTable]] node with an [[UpdateCommand]].
 * An update condition containing a subquery is rejected with an error.
 */
case class PreprocessTableUpdate(conf: SQLConf)
  extends Rule[LogicalPlan] with UpdateExpressionsSupport {

  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.resolveOperators {
      case updateNode: DeltaUpdateTable if updateNode.resolved =>
        // Reject the statement as soon as the predicate is found to contain a subquery.
        for (predicate <- updateNode.condition if SubqueryExpression.hasSubquery(predicate)) {
          throw DeltaErrors.subqueryNotSupportedException("UPDATE", predicate)
        }
        toCommand(updateNode)
    }
  }

  /**
   * Builds the [[UpdateCommand]] for `update`. The child plan, with subquery aliases eliminated,
   * must match `DeltaFullTable`; the update expressions are then aligned with the child's
   * output through `generateUpdateExpressions`.
   */
  def toCommand(update: DeltaUpdateTable): UpdateCommand = {
    val target = EliminateSubqueryAliases(update.child)
    val index = target match {
      case DeltaFullTable(tahoeFileIndex) => tahoeFileIndex
      case _ => throw DeltaErrors.notADeltaSourceException("UPDATE", Some(target))
    }

    val targetColNameParts =
      update.updateColumns.map(column => DeltaUpdateTable.getTargetColNameParts(column))
    val alignedUpdateExprs = generateUpdateExpressions(
      update.child.output,
      targetColNameParts,
      update.updateExpressions,
      conf.resolver)
    UpdateCommand(index, update.child, alignedUpdateExprs, update.condition)
  }
}
Example 66
Source File: DeltaTableIdentifier.scala    From delta   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.sources.DeltaSourceUtils
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.internal.SQLConf


  /**
   * Classifies `identifier`: the path check runs first, then the table check.
   * Returns `None` when neither check succeeds.
   */
  def apply(spark: SparkSession, identifier: TableIdentifier): Option[DeltaTableIdentifier] = {
    identifier match {
      case id if isDeltaPath(spark, id) =>
        Some(DeltaTableIdentifier(path = Option(id.table)))
      case id if DeltaTableUtils.isDeltaTable(spark, id) =>
        Some(DeltaTableIdentifier(table = Option(id)))
      case _ =>
        None
    }
  }
} 
Example 67
Source File: DeltaSparkSessionExtension.scala    From delta   with Apache License 2.0 5 votes vote down vote up
package io.delta.sql

import org.apache.spark.sql.delta._
import io.delta.sql.parser.DeltaSqlParser

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.internal.SQLConf


/**
 * Session extension that registers the Delta parser, one resolution rule, one check rule and
 * three post-hoc resolution rules (update, merge, delete - in that order).
 */
class DeltaSparkSessionExtension extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // The Delta parser wraps whichever parser is handed in; the session argument is unused.
    extensions.injectParser((_, delegate) => new DeltaSqlParser(delegate))

    extensions.injectResolutionRule { spark =>
      new DeltaAnalysis(spark, spark.sessionState.conf)
    }

    extensions.injectCheckRule(spark => new DeltaUnsupportedOperationsCheck(spark))

    // Post-hoc rules only need the session's SQLConf.
    extensions.injectPostHocResolutionRule { spark =>
      new PreprocessTableUpdate(spark.sessionState.conf)
    }
    extensions.injectPostHocResolutionRule { spark =>
      new PreprocessTableMerge(spark.sessionState.conf)
    }
    extensions.injectPostHocResolutionRule { spark =>
      new PreprocessTableDelete(spark.sessionState.conf)
    }
  }
}
Example 68
Source File: EvolvabilitySuite.scala    From delta   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.util.{FileNames, JsonUtils}
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.functions.typedLit
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils

/**
 * Tests that check Delta tables and log entries stay readable: a stored 0.1.0 table fixture,
 * log files holding an unrecognised action, null partition values, and an older
 * `CheckpointMetaData` JSON form.
 */
class EvolvabilitySuite extends EvolvabilitySuiteBase with SQLTestUtils {

  import testImplicits._

  // Runs the shared evolvability check against the stored delta-0.1.0 fixture.
  test("delta 0.1.0") {
    testEvolvability("src/test/resources/delta/delta-0.1.0")
  }

  // Same fixture, with case-sensitive analysis switched on for the duration of the check.
  test("delta 0.1.0 - case sensitivity enabled") {
    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
      testEvolvability("src/test/resources/delta/delta-0.1.0")
    }
  }

  // Writes log files containing an unknown `some_new_feature` entry and expects both
  // `deltaLog.update()` and a running streaming query to keep working.
  testQuietly("future proofing against new features") {
    val tempDir = Utils.createTempDir().toString
    Seq(1, 2, 3).toDF().write.format("delta").save(tempDir)

    val deltaLog = DeltaLog.forTable(spark, tempDir)
    // Hand-written log file that holds only the unrecognised entry.
    deltaLog.store.write(new Path(deltaLog.logPath, "00000000000000000001.json"),
      Iterator("""{"some_new_feature":{"a":1}}"""))

    // Shouldn't fail here
    deltaLog.update()

    // Streaming aggregation over the same table, printed to the console sink.
    val sq = spark.readStream.format("delta").load(tempDir.toString)
      .groupBy()
      .count()
      .writeStream
      .outputMode("complete")
      .format("console")
      .start()

    // Also shouldn't fail
    sq.processAllAvailable()
    Seq(1, 2, 3).toDF().write.format("delta").mode("append").save(tempDir)
    sq.processAllAvailable()
    // A second hand-written file with the unrecognised entry, added while the query is running.
    deltaLog.store.write(new Path(deltaLog.logPath, "00000000000000000003.json"),
      Iterator("""{"some_new_feature":{"a":1}}"""))
    sq.processAllAvailable()
    sq.stop()
  }

  // Every file entry must carry the "part" key (value null or "1"), and the first log file
  // must contain the literal `"part":null`.
  test("serialized partition values must contain null values") {
    val tempDir = Utils.createTempDir().toString
    val df1 = spark.range(5).withColumn("part", typedLit[String](null))
    val df2 = spark.range(5).withColumn("part", typedLit("1"))
    df1.union(df2).coalesce(1).write.partitionBy("part").format("delta").save(tempDir)

    // Clear the cache
    DeltaLog.clearCache()
    val deltaLog = DeltaLog.forTable(spark, tempDir)

    val dataThere = deltaLog.snapshot.allFiles.collect().forall { addFile =>
      // A missing key fails the test immediately instead of just yielding false.
      if (!addFile.partitionValues.contains("part")) {
        fail(s"The partition values: ${addFile.partitionValues} didn't contain the column 'part'.")
      }
      val value = addFile.partitionValues("part")
      value === null || value === "1"
    }

    assert(dataThere, "Partition values didn't match with null or '1'")

    // Check serialized JSON as well
    val contents = deltaLog.store.read(FileNames.deltaFile(deltaLog.logPath, 0L))
    assert(contents.exists(_.contains(""""part":null""")), "null value should be written in json")
  }

  // JSON with only `version` and `size` must deserialize with the third field set to None.
  testQuietly("parse old version CheckpointMetaData") {
    assert(JsonUtils.mapper.readValue[CheckpointMetaData]("""{"version":1,"size":1}""")
      == CheckpointMetaData(1, 1, None))
  }
}
Example 69
Source File: DeltaHiveTest.scala    From delta   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.delta.test

import org.apache.spark.sql.delta.catalog.DeltaCatalog
import io.delta.sql.DeltaSparkSessionExtension
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.test.SQLTestUtils


/**
 * Test mixin that replaces the shared `TestHive` session with a fresh Hive-backed session
 * configured with [[DeltaCatalog]] as the V2 session catalog and
 * [[DeltaSparkSessionExtension]] as a session extension.
 */
trait DeltaHiveTest extends SparkFunSuite with BeforeAndAfterAll { self: SQLTestUtils =>

  private var _session: SparkSession = _
  private var _hiveContext: TestHiveContext = _
  private var _sc: SparkContext = _

  /**
   * Stops the shared `TestHive` session, then builds a new local [[SparkContext]] and
   * [[TestHiveContext]] from a copy of its configuration with the Delta settings added.
   */
  override def beforeAll(): Unit = {
    val conf = TestHive.sparkSession.sparkContext.getConf.clone()
    TestHive.sparkSession.stop()
    conf.set(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[DeltaCatalog].getName)
    conf.set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key,
      classOf[DeltaSparkSessionExtension].getName)
    _sc = new SparkContext("local", this.getClass.getName, conf)
    _hiveContext = new TestHiveContext(_sc)
    _session = _hiveContext.sparkSession
    SparkSession.setActiveSession(_session)
    super.beforeAll()
  }

  override protected def spark: SparkSession = _session

  /**
   * Resets the Hive context, stops the Spark context and then runs the parent teardown.
   *
   * Each step sits in its own `finally` so that a failure in one does not skip the others.
   * The fields are null-checked because `afterAll` can run after `beforeAll` failed before
   * assigning them.
   */
  override def afterAll(): Unit = {
    try {
      Option(_hiveContext).foreach(_.reset())
    } finally {
      try {
        Option(_sc).foreach(_.stop())
      } finally {
        // Mirror the super.beforeAll() call made in beforeAll().
        super.afterAll()
      }
    }
  }
}
Example 70
Source File: SparkPlanner.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.internal.SQLConf

/**
 * Planner that combines caller-supplied strategies with a fixed list of built-in ones.
 *
 * @param sparkContext    context whose `sparkUser` is handed to the file-source strategy
 * @param conf            SQL configuration (source of the shuffle partition count)
 * @param extraStrategies strategies placed ahead of the built-in ones
 */
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val extraStrategies: Seq[Strategy])
  extends SparkStrategies {

  def user: String = sparkContext.sparkUser

  def numPartitions: Int = conf.numShufflePartitions

  /** Extra strategies first, followed by the built-in ones in a fixed order. */
  def strategies: Seq[Strategy] = {
    val builtIn = List(
      FileSourceStrategy(user),
      DataSourceStrategy,
      DDLStrategy,
      SpecialLimits,
      Aggregation,
      JoinSelection,
      InMemoryScans,
      BasicOperators)
    extraStrategies ++ builtIn
  }

  override protected def collectPlaceholders(
     plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = {
    plan.collect {
      case placeholder @ PlanLater(logicalPlan, _) => (placeholder, logicalPlan)
    }
  }

  override protected def prunePlans(plans: Iterator[SparkPlan]): Iterator[SparkPlan] = {
    // TODO: We will need to prune bad plans when we improve plan space exploration
    //       to prevent combinatorial explosion.
    plans
  }

  /**
   * Builds a scan through `scanBuilder` and wraps it in a filter and/or projection as needed.
   *
   * @param projectList            expressions the result must produce
   * @param filterPredicates       predicates that apply to the scan
   * @param prunePushedDownFilters selects the predicates that still need a filter node
   * @param scanBuilder            creates the scan for a given list of attributes
   */
  def pruneFilterProject(
      projectList: Seq[NamedExpression],
      filterPredicates: Seq[Expression],
      prunePushedDownFilters: Seq[Expression] => Seq[Expression],
      scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {

    val projectSet = AttributeSet(projectList.flatMap(_.references))
    val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
    val filterCondition: Option[Expression] =
      prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And)

    // Adds the filter node on top of `scan` when a residual condition exists.
    def filtered(scan: SparkPlan): SparkPlan = filterCondition match {
      case Some(condition) => FilterExec(condition, scan)
      case None => scan
    }

    // Right now we still use a projection even if the only evaluation is applying an alias
    // to a column.  Since this is a no-op, it could be avoided. However, using this
    // optimization with the current implementation would change the output schema.
    // TODO: Decouple final output schema from expression evaluation so this copy can be
    // avoided safely.

    val projectionIsPlainPruning =
      AttributeSet(projectList.map(_.toAttribute)) == projectSet && filterSet.subsetOf(projectSet)

    if (projectionIsPlainPruning) {
      // Column pruning alone yields the requested projection and the projected columns are
      // enough to evaluate every filter: scan, then filter, with no extra project.
      filtered(scanBuilder(projectList.asInstanceOf[Seq[Attribute]]))
    } else {
      val scan = scanBuilder((projectSet ++ filterSet).toSeq)
      ProjectExec(projectList, filtered(scan))
    }
  }
}
Example 71
Source File: ParquetOptions.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.hadoop.metadata.CompressionCodecName

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf


  // The MERGE_SCHEMA entry in `parameters` takes precedence when present; otherwise the value
  // comes from sqlConf.isParquetSchemaMergingEnabled.
  val mergeSchema: Boolean =
    parameters.get(MERGE_SCHEMA).fold(sqlConf.isParquetSchemaMergingEnabled)(_.toBoolean)
}


/** Option keys and codec short names used with Parquet options. */
object ParquetOptions {
  /** Key of the schema-merging option. */
  val MERGE_SCHEMA = "mergeSchema"

  // Lower-case short name -> Parquet codec ("none" and "uncompressed" are synonyms).
  private val shortParquetCompressionCodecNames = {
    import CompressionCodecName.{GZIP, LZO, SNAPPY, UNCOMPRESSED}
    Map(
      "none" -> UNCOMPRESSED,
      "uncompressed" -> UNCOMPRESSED,
      "snappy" -> SNAPPY,
      "gzip" -> GZIP,
      "lzo" -> LZO)
  }
}
Example 72
Source File: Exchange.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType


/**
 * Rule that replaces an [[Exchange]] with a `ReusedExchangeExec` pointing at an earlier
 * exchange when `sameResult` reports that both produce the same result. Does nothing when
 * `conf.exchangeReuseEnabled` is false.
 */
case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (conf.exchangeReuseEnabled) {
      // Bucket exchanges by schema so sameResult only runs against likely candidates,
      // avoiding O(N*N) comparisons.
      val exchangesBySchema = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
      plan.transformUp {
        case exchange: Exchange =>
          // Exchanges with the same result usually share a schema (same column names).
          val candidates =
            exchangesBySchema.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
          candidates.find(seen => exchange.sameResult(seen)) match {
            case Some(equivalent) =>
              // Keep this exchange's output; the plans above need it to resolve attributes.
              ReusedExchangeExec(exchange.output, equivalent, plan.user)
            case None =>
              candidates += exchange
              exchange
          }
      }
    } else {
      plan
    }
  }
}
Example 73
Source File: FileStreamSinkLog.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import org.apache.hadoop.fs.{FileStatus, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization
import org.json4s.jackson.Serialization.{read, write}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf


/**
 * Compactible log of [[SinkFileStatus]] entries. Cleanup delay, log deletion and compaction
 * interval are all read from the session's SQL configuration.
 *
 * @param metadataLogVersion forwarded to [[CompactibleFileStreamLog]]
 * @param sparkSession       session whose configuration supplies the settings above
 * @param path               forwarded to [[CompactibleFileStreamLog]]
 */
class FileStreamSinkLog(
    metadataLogVersion: String,
    sparkSession: SparkSession,
    path: String)
  extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) {

  private implicit val formats = Serialization.formats(NoTypeHints)

  protected override val fileCleanupDelayMs = sparkSession.sessionState.conf.fileSinkLogCleanupDelay

  protected override val isDeletingExpiredLog = sparkSession.sessionState.conf.fileSinkLogDeletion

  protected override val defaultCompactInterval =
    sparkSession.sessionState.conf.fileSinkLogCompactInterval

  // Fail construction when the configured interval is not strictly positive.
  require(
    defaultCompactInterval > 0,
    s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $defaultCompactInterval) " +
      "to a positive value.")

  /**
   * Drops every entry whose path also appears in an entry marked with the delete action.
   * The input is returned untouched when no such entry exists.
   */
  override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = {
    val deletedPaths = logs.collect {
      case entry if entry.action == FileStreamSinkLog.DELETE_ACTION => entry.path
    }.toSet
    if (deletedPaths.nonEmpty) {
      logs.filterNot(entry => deletedPaths.contains(entry.path))
    } else {
      logs
    }
  }
}

/** String constants used by the file stream sink log. */
object FileStreamSinkLog {
  /** Version tag. */
  val VERSION: String = "v1"

  /** Action value marking an entry as deleted. */
  val DELETE_ACTION: String = "delete"

  /** Action value marking an entry as added. */
  val ADD_ACTION: String = "add"
}
Example 74
Source File: SparkOptimizer.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.internal.SQLConf

/**
 * Optimizer that appends four batches to those inherited from [[Optimizer]]; the last one
 * runs the rules registered in `experimentalMethods.extraOptimizations`.
 */
class SparkOptimizer(
    catalog: SessionCatalog,
    conf: SQLConf,
    experimentalMethods: ExperimentalMethods)
  extends Optimizer(catalog, conf) {

  override def batches: Seq[Batch] = {
    val inherited = super.batches
    val appended = Seq(
      Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)),
      Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate),
      Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions),
      Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*))
    inherited ++ appended
  }
}
Example 75
Source File: ExecuteStatementInClientModeWithHDFSSuite.scala    From kyuubi   with Apache License 2.0 5 votes vote down vote up
package yaooqinn.kyuubi.operation

import java.io.{File, IOException}

import scala.util.Try

import org.apache.hadoop.fs.Path
import org.apache.hadoop.hdfs.{HdfsConfiguration, MiniDFSCluster}
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.sql.catalyst.catalog.FunctionResource
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.internal.SQLConf
import org.mockito.Mockito.when

import yaooqinn.kyuubi.operation.statement.ExecuteStatementInClientMode
import yaooqinn.kyuubi.utils.{KyuubiHiveUtil, ReflectUtils}

/**
 * Re-runs the client-mode statement tests against a [[MiniDFSCluster]] and adds HDFS-specific
 * cases. A jar named `example-1.0.0-SNAPSHOT.jar` is copied into the cluster's home directory
 * before the tests and removed afterwards.
 */
class ExecuteStatementInClientModeWithHDFSSuite extends ExecuteStatementInClientModeSuite {
  val hdfsConf = new HdfsConfiguration
  hdfsConf.set("fs.hdfs.impl.disable.cache", "true")
  var cluster: MiniDFSCluster = new MiniDFSCluster.Builder(hdfsConf).build()
  cluster.waitClusterUp()
  val fs = cluster.getFileSystem
  val homeDirectory: Path = fs.getHomeDirectory
  private val fileName = "example-1.0.0-SNAPSHOT.jar"
  private val remoteUDFFile = new Path(homeDirectory, fileName)

  /** Copies the jar located next to this class's code source into the mini cluster. */
  override def beforeAll(): Unit = {
    val file = new File(this.getClass.getProtectionDomain.getCodeSource.getLocation + fileName)
    val localUDFFile = new Path(file.getPath)
    fs.copyFromLocalFile(localUDFFile, remoteUDFFile)
    super.beforeAll()
  }

  /**
   * Removes the uploaded jar, closes the file system, shuts the cluster down and runs the
   * parent teardown. Every step sits in its own `finally` so that a failure in an earlier
   * step cannot leave the cluster running or skip the parent teardown.
   */
  override def afterAll(): Unit = {
    try {
      fs.delete(remoteUDFFile, true)
    } finally {
      try {
        fs.close()
      } finally {
        try {
          cluster.shutdown()
        } finally {
          super.afterAll()
        }
      }
    }
  }

  // For both the `using file` and `using jar` forms, the transformed plan must equal the
  // parsed plan and its "resources" field must be empty.
  test("transform logical plan") {
    val op = sessionMgr.getOperationMgr.newExecuteStatementOperation(session, statement)
      .asInstanceOf[ExecuteStatementInClientMode]
    val parser = new SparkSqlParser(new SQLConf)
    val plan0 = parser.parsePlan(
      s"create temporary function a as 'a.b.c' using file '$remoteUDFFile'")
    val plan1 = op.transform(plan0)
    assert(plan0 === plan1)
    assert(
      ReflectUtils.getFieldValue(plan1, "resources").asInstanceOf[Seq[FunctionResource]].isEmpty)

    val plan2 = parser.parsePlan(
      s"create temporary function a as 'a.b.c' using jar '$remoteUDFFile'")
    val plan3 = op.transform(plan2)
    assert(plan3 === plan2)
    assert(
      ReflectUtils.getFieldValue(plan3, "resources").asInstanceOf[Seq[FunctionResource]].isEmpty)
  }

  test("add delegation token with hive session state, hdfs") {
    val hiveConf = new HiveConf(hdfsConf, classOf[HiveConf])
    val state = new SessionState(hiveConf)
    assert(Try {
      KyuubiHiveUtil.addDelegationTokensToHiveState(state, UserGroupInformation.getCurrentUser)
    }.isSuccess)

    // A user whose getUserName throws IOException must not make the call itself throw.
    val mockuser = mock[UserGroupInformation]
    when(mockuser.getUserName).thenThrow(classOf[IOException])
    KyuubiHiveUtil.addDelegationTokensToHiveState(state, mockuser)
  }
}
Example 76
Source File: HBaseSparkSession.scala    From Heracles   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hbase

import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.catalog.ExternalCatalog
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.SparkPlanner
import org.apache.spark.sql.hbase.execution.{HBaseSourceAnalysis, HBaseStrategies}
import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SQLConf, SessionState, SharedState}

/**
 * [[SparkSession]] whose session state comes from [[HBaseSessionStateBuilder]] and whose shared
 * state is an [[HBaseSharedState]].
 */
class HBaseSparkSession(sc: SparkContext) extends SparkSession(sc) {
  self =>

  /** Convenience constructor for callers holding a [[JavaSparkContext]]. */
  def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)

  // Runs at construction time: merges the result of HBaseConfiguration.create into the
  // context's Hadoop configuration.
  HBaseConfiguration.merge(
    sc.hadoopConfiguration,
    HBaseConfiguration.create(sc.hadoopConfiguration))

  @transient
  override lazy val sharedState: SharedState = new HBaseSharedState(sc, this.sqlContext)

  @transient
  override lazy val sessionState: SessionState = {
    val stateBuilder = new HBaseSessionStateBuilder(this)
    stateBuilder.build()
  }
}

/**
 * Session state builder that uses an [[HBaseSQLConf]], registers the HBase data-source
 * strategy and adds [[HBaseSourceAnalysis]] to the analyzer's post-hoc resolution rules.
 *
 * @param session     session the state is built for
 * @param parentState state to derive from; forwarded to [[BaseSessionStateBuilder]] so that
 *                    builders created through `newBuilder` do not silently discard it
 */
class HBaseSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None)
  extends BaseSessionStateBuilder(session, parentState) {

  override lazy val conf: SQLConf = new HBaseSQLConf

  override protected def newBuilder: NewBuilder = new HBaseSessionStateBuilder(_, _)

  // The only extra strategy is HBaseDataSource, taken from a throw-away planner that mixes in
  // HBaseStrategies.
  override lazy val experimentalMethods: ExperimentalMethods = {
    val result = new ExperimentalMethods
    result.extraStrategies = Seq(
      (new SparkPlanner(session.sparkContext, conf, new ExperimentalMethods)
        with HBaseStrategies).HBaseDataSource)
    result
  }

  override lazy val analyzer: Analyzer = {
    new Analyzer(catalog, conf) {
      override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
          new FindDataSourceTable(session) +:
          new ResolveSQLOnFile(session) +:
          customResolutionRules

      override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
          PreprocessTableCreation(session) +:
          PreprocessTableInsertion(conf) +:
          DataSourceAnalysis(conf) +:
          HBaseSourceAnalysis(session) +:
          customPostHocResolutionRules

      override val extendedCheckRules =
        customCheckRules
    }
  }
}

/** [[SharedState]] whose external catalog is an [[HBaseCatalog]]. */
class HBaseSharedState(sc: SparkContext, sqlContext: SQLContext) extends SharedState(sc) {
  override lazy val externalCatalog: ExternalCatalog = {
    val hadoopConf = sc.hadoopConfiguration
    new HBaseCatalog(sqlContext, hadoopConf)
  }
}
Example 77
Source File: HBaseSQLConf.scala    From Heracles   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hbase

import org.apache.spark.sql.internal.SQLConf

/** String constants: `spark.sql.hbase.*` configuration keys followed by plain option names. */
object HBaseSQLConf {
  // Configuration keys.
  val PARTITION_EXPIRATION: String = "spark.sql.hbase.partition.expiration"
  val SCANNER_FETCH_SIZE: String = "spark.sql.hbase.scanner.fetchsize"
  val USE_COPROCESSOR: String = "spark.sql.hbase.coprocessor"
  val USE_CUSTOMFILTER: String = "spark.sql.hbase.customfilter"

  // Option names and values.
  val PROVIDER: String = "provider"
  val HBASE: String = "hbase"
  val COLS: String = "cols"
  val KEY_COLS: String = "keyCols"
  val NONKEY_COLS: String = "nonKeyCols"
  val HBASE_TABLENAME: String = "hbaseTableName"
  val ENCODING_FORMAT: String = "encodingFormat"
}


  /** Value stored under PARTITION_EXPIRATION as a Long; "600" is used when the key is unset. */
  private[hbase] def partitionExpiration: Long = {
    val configured = getConfString(PARTITION_EXPIRATION, "600")
    configured.toLong
  }
  /** Value stored under SCANNER_FETCH_SIZE as an Int; "1000" is used when the key is unset. */
  private[hbase] def scannerFetchSize: Int = {
    val configured = getConfString(SCANNER_FETCH_SIZE, "1000")
    configured.toInt
  }
  /** Value stored under USE_COPROCESSOR as a Boolean; "false" is used when the key is unset. */
  private[hbase] def useCoprocessor: Boolean = {
    val configured = getConfString(USE_COPROCESSOR, "false")
    configured.toBoolean
  }
  /** Value stored under USE_CUSTOMFILTER as a Boolean; "true" is used when the key is unset. */
  private[hbase] def useCustomFilter: Boolean = {
    val configured = getConfString(USE_CUSTOMFILTER, "true")
    configured.toBoolean
  }
} 
Example 78
Source File: HBaseAdvancedSQLQuerySuite.scala    From Heracles   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hbase

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{MetadataBuilder, StructType}
import org.apache.spark.sql.{DataFrame, Row}

/**
 * Query tests against table `ta`: aggregation with whole-stage codegen enabled, simple DSL
 * selects, and propagation of column metadata through selects and joins.
 */
class HBaseAdvancedSQLQuerySuite extends TestBaseWithSplitData {
  import org.apache.spark.sql.hbase.TestHbase._
  import org.apache.spark.sql.hbase.TestHbase.implicits._

  test("aggregation with codegen") {
    val originalValue = TestHbase.sessionState.conf.wholeStageEnabled
    TestHbase.sessionState.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
    try {
      val result = sql("SELECT col1 FROM ta GROUP BY col1").collect()
      assert(result.length == 14, "aggregation with codegen test failed on size")
    } finally {
      // Restore the shared setting even when the query or the assertion fails, so the
      // override cannot leak into later tests.
      TestHbase.sessionState.conf.setConfString(
        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, originalValue.toString)
    }
  }

  test("dsl simple select 0") {
    val tableA = sql("SELECT * FROM ta")
    checkAnswer(
      tableA.where('col7 === 1).orderBy('col2.asc).select('col4),
      Row(1) :: Nil)
    checkAnswer(
      tableA.where('col2 === 6).orderBy('col2.asc).select('col7),
      Row(-31) :: Nil)
  }

  test("metadata is propagated correctly") {
    val tableA = sql("SELECT col7, col1, col3 FROM ta")
    val schema = tableA.schema
    val docKey = "doc"
    val docValue = "first name"
    val metadata = new MetadataBuilder()
      .putString(docKey, docValue)
      .build()
    // Same schema as the query result, with the metadata attached to col1 only.
    val schemaWithMeta = new StructType(Array(
      schema("col7"), schema("col1").copy(metadata = metadata), schema("col3")))
    val personWithMeta = createDataFrame(tableA.rdd, schemaWithMeta)
    // Asserts that col1 of the given frame still carries the doc entry.
    def validateMetadata(rdd: DataFrame): Unit = {
      assert(rdd.schema("col1").metadata.getString(docKey) == docValue)
    }
    personWithMeta.createOrReplaceTempView("personWithMeta")
    validateMetadata(personWithMeta.select($"col1"))
    validateMetadata(personWithMeta.select($"col1"))
    validateMetadata(personWithMeta.select($"col7", $"col1"))
    validateMetadata(sql("SELECT * FROM personWithMeta"))
    validateMetadata(sql("SELECT col7, col1 FROM personWithMeta"))
    validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON col7 = personId"))
    validateMetadata(sql("SELECT col1, salary FROM personWithMeta JOIN salary ON col7 = personId"))
  }
}
Example 79
Source File: SparkSessionUtils.scala    From mist   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}


/** Helper for obtaining a [[SparkSession]] bound to an existing [[SparkContext]]. */
object SparkSessionUtils {

  /**
   * Returns a session for `sc`, seeded with the context's configuration.
   *
   * @param sc              context the session is bound to
   * @param withHiveSupport when true, the context's catalog implementation setting is set to
   *                        "hive" and Hive support is enabled on the builder
   */
  def getOrCreate(sc: SparkContext, withHiveSupport: Boolean): SparkSession = {
    val base = SparkSession.builder().sparkContext(sc).config(sc.conf)

    val configured =
      if (withHiveSupport) {
        sc.conf.set(StaticSQLConf.CATALOG_IMPLEMENTATION.key, "hive")
        base.enableHiveSupport()
      } else {
        base
      }

    configured.getOrCreate()
  }
}
Example 80
Source File: QueryPartitionSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import java.io.File
import java.sql.Timestamp

import com.google.common.io.Files
import org.apache.hadoop.fs.FileSystem

import org.apache.spark.sql._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils

/**
 * Hive partition query tests: reading a partitioned table after one partition directory has
 * been deleted, and querying a table partitioned by a timestamp column.
 */
class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  import spark.implicits._

  test("SPARK-5068: query data when path doesn't exist") {
    withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) {
      // The with* helpers guarantee that the temp view, the tables and the temp directory are
      // cleaned up even when an assertion below fails.
      withTempView("testData") {
        withTable("table_with_partition", "createAndInsertTest") {
          withTempDir { tmpDir =>
            val testData = sparkContext.parallelize(
              (1 to 10).map(i => TestData(i, i.toString))).toDF()
            testData.createOrReplaceTempView("testData")

            // create the table for test
            sql(s"CREATE TABLE table_with_partition(key int,value string) " +
              s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ")
            sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='1') " +
              "SELECT key,value FROM testData")
            sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='2') " +
              "SELECT key,value FROM testData")
            sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='3') " +
              "SELECT key,value FROM testData")
            sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='4') " +
              "SELECT key,value FROM testData")

            // test for the exist path
            checkAnswer(sql("select key,value from table_with_partition"),
              testData.toDF.collect ++ testData.toDF.collect
                ++ testData.toDF.collect ++ testData.toDF.collect)

            // delete the path of one partition
            tmpDir.listFiles
              .find { f => f.isDirectory && f.getName().startsWith("ds=") }
              .foreach { f => Utils.deleteRecursively(f) }

            // test for after delete the path
            checkAnswer(sql("select key,value from table_with_partition"),
              testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect)
          }
        }
      }
    }
  }

  test("SPARK-21739: Cast expression should initialize timezoneId") {
    withTable("table_with_timestamp_partition") {
      sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)")
      sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " +
        "PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)")

      // test for Cast expression in TableReader
      checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"),
        Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000"))))

      // test for Cast expression in HiveTableScanExec
      checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " +
        "WHERE ts = '2010-01-01 00:00:00.000'"), Row(1))
    }
  }
}
Example 81
Source File: TestHiveSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.hive.test.{TestHiveSingleton, TestHiveSparkSession}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils


/**
 * Tests for the Hive test session: which test tables get loaded depending on the
 * case-sensitivity setting, and whether a value expected from hive-site.xml is visible.
 */
class TestHiveSuite extends TestHiveSingleton with SQLTestUtils {
  test("load test table based on case sensitivity") {
    val testHiveSparkSession = spark.asInstanceOf[TestHiveSparkSession]

    // Case-insensitive: analysing a query on SRC must leave exactly one loaded table, "src".
    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
      sql("SELECT * FROM SRC").queryExecution.analyzed
      assert(testHiveSparkSession.getLoadedTables.contains("src"))
      assert(testHiveSparkSession.getLoadedTables.size == 1)
    }
    // Reset the session before switching the setting.
    testHiveSparkSession.reset()

    // Case-sensitive: the same query must fail analysis with "Table or view not found".
    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
      val err = intercept[AnalysisException] {
        sql("SELECT * FROM SRC").queryExecution.analyzed
      }
      assert(err.message.contains("Table or view not found"))
    }
    testHiveSparkSession.reset()
  }

  // "hive.in.test" is read with "" as the fallback, so the assertion fails when it is unset.
  test("SPARK-15887: hive-site.xml should be loaded") {
    assert(hiveClient.getConf("hive.in.test", "") == "true")
  }
}
Example 82
Source File: FiltersSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.client

import java.util.Collections

import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


/**
 * Checks how `Shim_v0_13.convertFilters` renders Catalyst predicates as filter strings for a
 * test table whose only partition column is a varchar column named "varchar".
 */
class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
  private val shim = new Shim_v0_13

  // Test table with a single varchar partition column; used by the "skip varchar" case.
  private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
  private val varCharCol = new FieldSchema()
  varCharCol.setName("varchar")
  varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
  testTable.setPartCols(Collections.singletonList(varCharCol))

  filterTest("string filter",
    (a("stringcol", StringType) > Literal("test")) :: Nil,
    "stringcol > \"test\"")

  filterTest("string filter backwards",
    (Literal("test") > a("stringcol", StringType)) :: Nil,
    "\"test\" > stringcol")

  filterTest("int filter",
    (a("intcol", IntegerType) === Literal(1)) :: Nil,
    "intcol = 1")

  filterTest("int filter backwards",
    (Literal(1) === a("intcol", IntegerType)) :: Nil,
    "1 = intcol")

  // The string column is declared as StringType so that it agrees with the string literal it is
  // compared against (it was previously declared as IntegerType).
  filterTest("int and string filter",
    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", StringType)) :: Nil,
    "1 = intcol and \"a\" = strcol")

  filterTest("skip varchar",
    (Literal("") === a("varchar", StringType)) :: Nil,
    "")

  filterTest("SPARK-19912 String literals should be escaped for Hive metastore partition pruning",
    (a("stringcol", StringType) === Literal("p1\" and q=\"q1")) ::
      (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil,
    """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""")

  /**
   * Registers a test that converts `filters` with advanced partition predicate pushdown
   * enabled and fails unless the converted string equals `result`.
   *
   * @param name    test name
   * @param filters Catalyst predicates handed to `shim.convertFilters`
   * @param result  the exact filter string expected back
   */
  private def filterTest(name: String, filters: Seq[Expression], result: String) = {
    test(name) {
      withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") {
        val converted = shim.convertFilters(testTable, filters)
        if (converted != result) {
          fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'")
        }
      }
    }
  }

  test("turn on/off ADVANCED_PARTITION_PREDICATE_PUSHDOWN") {
    // The expression DSL is already imported at file level; no local import is needed.
    Seq(true, false).foreach { enabled =>
      withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> enabled.toString) {
        val filters =
          (Literal(1) === a("intcol", IntegerType) ||
            Literal(2) === a("intcol", IntegerType)) :: Nil
        val converted = shim.convertFilters(testTable, filters)
        if (enabled) {
          assert(converted == "(1 = intcol or 2 = intcol)")
        } else {
          // With the flag off, the disjunction is expected to yield no filter string.
          assert(converted.isEmpty)
        }
      }
    }
  }

  /** Shorthand for an attribute reference with the given name and data type. */
  private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
}
Example 83
Source File: SparkSQLOperationManager.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver.server

import java.util.{Map => JMap}
import java.util.concurrent.ConcurrentHashMap

import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation}
import org.apache.spark.sql.internal.SQLConf


/**
 * Operation manager that creates [[SparkExecuteStatementOperation]]s bound to the
 * `SQLContext` registered for the requesting session.
 */
private[thriftserver] class SparkSQLOperationManager()
  extends OperationManager with Logging {

  // Reflective handle on the parent's `handleToOperation` field.
  val handleToOperation = ReflectionUtils
    .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")

  val sessionToActivePool = new ConcurrentHashMap[SessionHandle, String]()
  val sessionToContexts = new ConcurrentHashMap[SessionHandle, SQLContext]()

  /**
   * Builds an execute-statement operation for `statement` and records it in
   * `handleToOperation`. Fails via `require` when no `SQLContext` is registered for the
   * session handle.
   */
  override def newExecuteStatementOperation(
      parentSession: HiveSession,
      statement: String,
      confOverlay: JMap[String, String],
      async: Boolean): ExecuteStatementOperation = synchronized {
    val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
    require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
      s" initialized or had already closed.")

    // Copy the session's overridden configurations, then its Hive variables, into the SQL conf.
    val sessionConf = sqlContext.sessionState.conf
    val hiveState = parentSession.getSessionState
    setConfMap(sessionConf, hiveState.getOverriddenConfigurations)
    setConfMap(sessionConf, hiveState.getHiveVariables)

    // Background execution needs both the caller's request and the async setting.
    val runInBackground = async && sessionConf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC)
    val created = new SparkExecuteStatementOperation(parentSession, statement, confOverlay,
      runInBackground)(sqlContext, sessionToActivePool)
    handleToOperation.put(created.getHandle, created)
    logDebug(s"Created Operation for $statement with session=$parentSession, " +
      s"runInBackground=$runInBackground")
    created
  }

  /** Sets every key/value pair of `confMap` on `conf` as a string config. */
  def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = {
    val entries = confMap.entrySet().iterator()
    while (entries.hasNext) {
      val entry = entries.next()
      conf.setConfString(entry.getKey, entry.getValue)
    }
  }
}
Example 84
Source File: SubstituteUnresolvedOrdinals.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType


/**
 * Rule that replaces integer literals used as ORDER BY / GROUP BY items with
 * `UnresolvedOrdinal`, each gated by its own flag on the supplied conf.
 */
class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {
  /** True only for literals whose data type is `IntegerType`. */
  private def isIntLiteral(e: Expression) = PartialFunction.cond(e) {
    case Literal(_, IntegerType) => true
  }

  /** Wraps an integer literal's value in an `UnresolvedOrdinal` carrying the literal's origin. */
  private def toOrdinal(literal: Literal, position: Int) =
    withOrigin(literal.origin)(UnresolvedOrdinal(position))

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    // ORDER BY <int literal>: swap the sort key for an ordinal, keeping origins intact.
    case sort: Sort if conf.orderByOrdinal && sort.order.exists(o => isIntLiteral(o.child)) =>
      val rewrittenOrder = sort.order.map {
        case sortOrder @ SortOrder(literal @ Literal(position: Int, IntegerType), _, _, _) =>
          val ordinal = toOrdinal(literal, position)
          withOrigin(sortOrder.origin)(sortOrder.copy(child = ordinal))
        case untouched => untouched
      }
      withOrigin(sort.origin)(sort.copy(order = rewrittenOrder))

    // GROUP BY <int literal>: same substitution on the grouping expressions.
    case agg: Aggregate if conf.groupByOrdinal && agg.groupingExpressions.exists(isIntLiteral) =>
      val rewrittenGroups = agg.groupingExpressions.map {
        case literal @ Literal(position: Int, IntegerType) => toOrdinal(literal, position)
        case untouched => untouched
      }
      withOrigin(agg.origin)(agg.copy(groupingExpressions = rewrittenGroups))
  }
}
Example 85
Source File: ResolveInlineTables.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}


  /**
   * Turns an inline table into a `LocalRelation`: a common type is derived per column, every
   * cell is cast to that type where needed, and each cell expression is evaluated.
   * Incompatible column types and evaluation failures are reported through
   * `table.failAnalysis`.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // Walk the table column-wise to settle each column's data type and nullability.
    val columns = table.rows.transpose
    val fields = columns.zip(table.names).map { case (cells, columnName) =>
      val widened = TypeCoercion.findWiderTypeWithoutStringPromotion(cells.map(_.dataType))
      val columnType = widened.getOrElse {
        table.failAnalysis(s"incompatible types found in column $columnName for inline table")
      }
      StructField(columnName, columnType, nullable = cells.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    // Evaluate each cell, casting first when its type differs from the column's type.
    val newRows: Seq[InternalRow] = table.rows.map { row =>
      val evaluatedCells = row.zipWithIndex.map { case (cell, columnIndex) =>
        val targetType = fields(columnIndex).dataType
        try {
          val toEvaluate = if (cell.dataType.sameType(targetType)) cell else cast(cell, targetType)
          toEvaluate.eval()
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${cell.sql}: ${ex.getMessage}", ex)
        }
      }
      InternalRow.fromSeq(evaluatedCells)
    }

    LocalRelation(attributes, newRows)
  }
} 
Example 86
Source File: StatsEstimationTestBase.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}


trait StatsEstimationTestBase extends SparkFunSuite {

  // CBO flag as it was before the suite started; restored in afterAll.
  var originalValue: Boolean = false

  /** Records the current CBO setting and then switches CBO on. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    val activeConf = SQLConf.get
    originalValue = activeConf.getConf(SQLConf.CBO_ENABLED)
    // Enable stats estimation based on CBO.
    activeConf.setConf(SQLConf.CBO_ENABLED, true)
  }

  /** Puts the CBO setting back to the value recorded in beforeAll. */
  override def afterAll(): Unit = {
    val activeConf = SQLConf.get
    activeConf.setConf(SQLConf.CBO_ENABLED, originalValue)
    super.afterAll()
  }

  /**
   * Size used for a column in these tests: the stat's average length, plus 12 bytes of
   * fixed overhead when the attribute is a string.
   */
  def getColSize(attribute: Attribute, colStat: ColumnStat): Long = {
    val averageLength = colStat.avgLen
    attribute.dataType match {
      // For UTF8String: base + offset + numBytes
      case StringType => averageLength + 8 + 4
      case _ => averageLength
    }
  }

  /** Creates an integer-typed attribute reference named `colName`. */
  def attr(colName: String): AttributeReference =
    AttributeReference(name = colName, dataType = IntegerType)()

  
/**
 * Leaf plan for tests whose statistics are supplied directly through its constructor
 * arguments instead of being derived from data.
 */
case class StatsTestPlan(
    outputList: Seq[Attribute],
    rowCount: BigInt,
    attributeStats: AttributeMap[ColumnStat],
    size: Option[BigInt] = None) extends LeafNode {
  override def output: Seq[Attribute] = outputList

  override def computeStats(): Statistics = {
    // If sizeInBytes is useless in testing, we just use a fake value
    val fallbackSize: BigInt = Int.MaxValue
    Statistics(
      sizeInBytes = size.getOrElse(fallbackSize),
      rowCount = Some(rowCount),
      attributeStats = attributeStats)
  }
}
Example 87
Source File: AnalysisTest.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import java.net.URI
import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

/**
 * Shared helpers for analyzer tests: a case-sensitive and a case-insensitive analyzer, each
 * backed by an in-memory catalog with three temp views, plus assertion helpers.
 */
trait AnalysisTest extends PlanTest {

  protected val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true)
  protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false)

  /** Builds an analyzer over a fresh in-memory catalog with the requested case sensitivity. */
  private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
    val analyzerConf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
    val sessionCatalog =
      new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, analyzerConf)
    sessionCatalog.createDatabase(
      CatalogDatabase("default", "", new URI("loc"), Map.empty),
      ignoreIfExists = false)

    // Register the mixed-case temp views used by the tests.
    Seq(
      "TaBlE" -> TestRelations.testRelation,
      "TaBlE2" -> TestRelations.testRelation2,
      "TaBlE3" -> TestRelations.testRelation3
    ).foreach { case (viewName, relation) =>
      sessionCatalog.createTempView(viewName, relation, overrideIfExists = true)
    }

    new Analyzer(sessionCatalog, analyzerConf) {
      override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
    }
  }

  /** Picks one of the two pre-built analyzers. */
  protected def getAnalyzer(caseSensitive: Boolean) = {
    if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
  }

  /** Analyzes `inputPlan` (with analysis checks) and compares the result to `expectedPlan`. */
  protected def checkAnalysis(
      inputPlan: LogicalPlan,
      expectedPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzed = getAnalyzer(caseSensitive).executeAndCheck(inputPlan)
    comparePlans(analyzed, expectedPlan)
  }

  protected override def comparePlans(
      plan1: LogicalPlan,
      plan2: LogicalPlan,
      checkAnalysis: Boolean = false): Unit = {
    // Analysis tests may have not been fully resolved, so skip checkAnalysis.
    super.comparePlans(plan1, plan2, checkAnalysis)
  }

  /** Fails the test, showing the input and partially analyzed plans, if analysis fails. */
  protected def assertAnalysisSuccess(
      inputPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val analysisAttempt = analyzer.execute(inputPlan)
    try {
      analyzer.checkAnalysis(analysisAttempt)
    } catch {
      case a: AnalysisException =>
        fail(
          s"""
            |Failed to Analyze Plan
            |$inputPlan
            |
            |Partial Analysis
            |$analysisAttempt
          """.stripMargin, a)
    }
  }

  /**
   * Expects analysis of `inputPlan` to throw an `AnalysisException` whose message contains
   * every string in `expectedErrors`, compared case-insensitively.
   */
  protected def assertAnalysisError(
      inputPlan: LogicalPlan,
      expectedErrors: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val e = intercept[AnalysisException] {
      analyzer.checkAnalysis(analyzer.execute(inputPlan))
    }

    val loweredExpectations = expectedErrors.map(_.toLowerCase(Locale.ROOT))
    val loweredMessage = e.getMessage.toLowerCase(Locale.ROOT)
    val everythingFound = loweredExpectations.forall(fragment => loweredMessage.contains(fragment))
    if (!everythingFound) {
      fail(
        s"""Exception message should contain the following substrings:
           |
           |  ${expectedErrors.mkString("\n  ")}
           |
           |Actual exception message:
           |
           |  ${e.getMessage}
         """.stripMargin)
    }
  }
}
Example 88
Source File: SubstituteUnresolvedOrdinalsSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.internal.SQLConf

/**
 * Tests for `SubstituteUnresolvedOrdinals`: the rule alone, full analysis, and the config
 * switches that turn ordinal substitution off.
 */
class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest {
  private lazy val a = testRelation2.output(0)
  private lazy val b = testRelation2.output(1)

  test("unresolved ordinal should not be unresolved") {
    // Expression OrderByOrdinal is unresolved.
    val ordinal = UnresolvedOrdinal(0)
    assert(!ordinal.resolved)
  }

  test("order by ordinal") {
    val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc)

    // Rule in isolation: integer literals become unresolved ordinals.
    val substituted = new SubstituteUnresolvedOrdinals(conf).apply(plan)
    val expectedSubstituted =
      testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc)
    comparePlans(substituted, expectedSubstituted)

    // Full analysis: ordinals resolve to the corresponding output attributes.
    checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc))

    // With ORDER_BY_ORDINAL disabled the rule must leave the literals untouched.
    val disabledConf = conf.copy(SQLConf.ORDER_BY_ORDINAL -> false)
    val notSubstituted = new SubstituteUnresolvedOrdinals(disabledConf).apply(plan)
    comparePlans(notSubstituted, testRelation2.orderBy(Literal(1).asc, Literal(2).asc))
  }

  test("group by ordinal") {
    val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)

    // Rule in isolation: integer literals become unresolved ordinals.
    val substituted = new SubstituteUnresolvedOrdinals(conf).apply(plan2)
    val expectedSubstituted =
      testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b)
    comparePlans(substituted, expectedSubstituted)

    // Full analysis: ordinals resolve to the corresponding output attributes.
    checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b))

    // With GROUP_BY_ORDINAL disabled the rule must leave the literals untouched.
    val disabledConf = conf.copy(SQLConf.GROUP_BY_ORDINAL -> false)
    val notSubstituted = new SubstituteUnresolvedOrdinals(disabledConf).apply(plan2)
    comparePlans(notSubstituted, testRelation2.groupBy(Literal(1), Literal(2))('a, 'b))
  }
}
Example 89
Source File: OptimizerStructuralIntegrityCheckerSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf


/**
 * Checks that the optimizer reports a broken plan when a rule introduces an unresolved
 * attribute.
 */
class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {

  /** Deliberately faulty rule: appends an unresolved attribute to every Project list. */
  object OptimizeRuleBreakSI extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case Project(projectList, child) =>
        Project(projectList :+ UnresolvedAttribute("unresolvedAttr"), child)
    }
  }

  /** Optimizer that runs the faulty rule once, ahead of the inherited batches. */
  object Optimize extends Optimizer(
    new SessionCatalog(
      new InMemoryCatalog,
      EmptyFunctionRegistry,
      new SQLConf())) {
    val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
    override def batches: Seq[Batch] = newBatch +: super.batches
  }

  test("check for invalid plan after execution of rule") {
    val projection = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation())
    val analyzed = projection.analyze
    assert(analyzed.resolved)

    val thrown = intercept[TreeNodeException[LogicalPlan]] {
      Optimize.execute(analyzed)
    }
    val message = thrown.getMessage
    val ruleName = OptimizeRuleBreakSI.ruleName
    assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
    assert(message.contains("the structural integrity of the plan is broken"))
  }
}
Example 90
Source File: RewriteDistinctAggregatesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for `RewriteDistinctAggregates`: plans with a single distinct group must come back
 * unchanged, plans with several distinct groups must be rewritten over an `Expand`.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  /** Fails unless the plan has the shape Aggregate over Aggregate over Expand. */
  private def checkRewrite(rewrite: LogicalPlan): Unit = {
    val hasRewrittenShape = rewrite match {
      case Aggregate(_, _, Aggregate(_, _, _: Expand)) => true
      case _ => false
    }
    if (!hasRewrittenShape) {
      fail(s"Plan is not rewritten:\n$rewrite")
    }
  }

  test("single distinct group") {
    val analyzedInput = testRelation.groupBy('a)(countDistinct('e)).analyze
    comparePlans(analyzedInput, RewriteDistinctAggregates(analyzedInput))
  }

  test("single distinct group with partial aggregates") {
    val analyzedInput = testRelation
      .groupBy('a, 'd)(countDistinct('e, 'c).as('agg1), max('b).as('agg2))
      .analyze
    comparePlans(analyzedInput, RewriteDistinctAggregates(analyzedInput))
  }

  test("multiple distinct groups") {
    val analyzedInput = testRelation.groupBy('a)(countDistinct('b, 'c), countDistinct('d)).analyze
    checkRewrite(RewriteDistinctAggregates(analyzedInput))
  }

  test("multiple distinct groups with partial aggregates") {
    val analyzedInput = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(analyzedInput))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val collected = CollectSet('b).toAggregateExpression()
    val analyzedInput = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), collected)
      .analyze
    checkRewrite(RewriteDistinctAggregates(analyzedInput))
  }
}
Example 91
Source File: EliminateSortsSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL}

/**
 * Tests for `EliminateSorts`, run together with `FoldablePropagation` up to a fixed point,
 * on a three-column integer relation.
 */
class EliminateSortsSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, ORDER_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("Eliminate Sorts", FixedPoint(10),
        FoldablePropagation,
        EliminateSorts))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Empty order by clause") {
    val sorted = testRelation.orderBy()
    val optimized = Optimize.execute(sorted.analyze)

    comparePlans(optimized, testRelation.analyze)
  }

  test("All the SortOrder are no-op") {
    // Both sort keys are integer literals, so the sort is expected to vanish entirely.
    val sorted = testRelation.orderBy(SortOrder(3, Ascending), SortOrder(-1, Ascending))
    val optimized = Optimize.execute(analyzer.execute(sorted))

    comparePlans(optimized, analyzer.execute(testRelation))
  }

  test("Partial order-by clauses contain no-op SortOrder") {
    // Only the literal key is expected to be dropped; the column key stays.
    val sorted = testRelation.orderBy(SortOrder(3, Ascending), 'a.asc)
    val optimized = Optimize.execute(analyzer.execute(sorted))
    val expected = analyzer.execute(testRelation.orderBy('a.asc))

    comparePlans(optimized, expected)
  }

  test("Remove no-op alias") {
    val sorted = testRelation
      .select('a.as('x), Year(CurrentDate()).as('y), 'b)
      .orderBy('x.asc, 'y.asc, 'b.desc)
    val optimized = Optimize.execute(analyzer.execute(sorted))

    // The expected plan no longer sorts on 'y.
    val expected = analyzer.execute(
      testRelation
        .select('a.as('x), Year(CurrentDate()).as('y), 'b)
        .orderBy('x.asc, 'b.desc))

    comparePlans(optimized, expected)
  }
}
Example 92
Source File: AggregateOptimizeSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}

/**
 * Tests for the grouping-expression clean-up rules (literal removal and repetition removal),
 * run together with `FoldablePropagation` up to a fixed point.
 */
class AggregateOptimizeSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("Aggregate", FixedPoint(100),
        FoldablePropagation,
        RemoveLiteralFromGroupExpressions,
        RemoveRepetitionFromGroupExpressions))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("remove literals in grouping expression") {
    val grouped = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(grouped))
    val expected = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, expected)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val grouped = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(grouped))
    // A single literal grouping key is expected to remain.
    val expected = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, expected)
  }

  test("Remove aliased literals") {
    val projected = testRelation.select('a, 'b, Literal(1).as('y))
    val optimized = Optimize.execute(analyzer.execute(projected.groupBy('a, 'y)(sum('b))))
    val expected = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, expected)
  }

  test("remove repetition in grouping expression") {
    val grouped =
      testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(grouped))
    val expected = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, expected)
  }
}
Example 93
Source File: SparkPlanner.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
import org.apache.spark.sql.internal.SQLConf

/**
 * Planner that combines user-supplied strategies with a fixed list of built-in ones and
 * offers a helper for building pruned, filtered and projected scans.
 */
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val experimentalMethods: ExperimentalMethods)
  extends SparkStrategies {

  def numPartitions: Int = conf.numShufflePartitions

  // Order: experimental strategies, then extra planning strategies, then the built-ins.
  override def strategies: Seq[Strategy] = {
    val builtIn = Seq[Strategy](
      DataSourceV2Strategy,
      FileSourceStrategy,
      DataSourceStrategy(conf),
      SpecialLimits,
      Aggregation,
      JoinSelection,
      InMemoryScans,
      BasicOperators)
    experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ builtIn
  }

  /**
   * Builds a scan via `scanBuilder`, applies the filters that `prunePushedDownFilters`
   * returns (AND-ed together), and adds a projection only when column pruning alone cannot
   * produce `projectList`.
   */
  def pruneFilterProject(
      projectList: Seq[NamedExpression],
      filterPredicates: Seq[Expression],
      prunePushedDownFilters: Seq[Expression] => Seq[Expression],
      scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {

    val projectSet = AttributeSet(projectList.flatMap(_.references))
    val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
    val filterCondition: Option[Expression] =
      prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And)

    // Puts a FilterExec on top of the scan when there is a remaining filter condition.
    def withFilter(scan: SparkPlan): SparkPlan = filterCondition match {
      case Some(condition) => FilterExec(condition, scan)
      case None => scan
    }

    // Right now we still use a projection even if the only evaluation is applying an alias
    // to a column.  Since this is a no-op, it could be avoided. However, using this
    // optimization with the current implementation would change the output schema.
    // TODO: Decouple final output schema from expression evaluation so this copy can be
    // avoided safely.

    // True when the projection is plain column pruning and those columns cover every filter.
    val pruningIsEnough =
      AttributeSet(projectList.map(_.toAttribute)) == projectSet &&
        filterSet.subsetOf(projectSet)

    if (pruningIsEnough) {
      // Scan followed by an optional filter, with no extra project.
      withFilter(scanBuilder(projectList.asInstanceOf[Seq[Attribute]]))
    } else {
      // Scan every referenced column, filter, then project to the requested output.
      val scan = scanBuilder((projectSet ++ filterSet).toSeq)
      ProjectExec(projectList, withFilter(scan))
    }
  }
}
Example 94
Source File: OrcOptions.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.orc

import java.util.Locale

import org.apache.orc.OrcConf.COMPRESS

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf


  /**
   * The ORC codec name selected for these options. Throws `IllegalArgumentException` when
   * the requested short name is not one of the known codecs.
   */
  val compressionCodec: String = {
    // `compression`, `orc.compress`(i.e., OrcConf.COMPRESS), and `spark.sql.orc.compression.codec`
    // are in order of precedence from highest to lowest.
    val orcCompressionConf = parameters.get(COMPRESS.getAttribute)
    val requestedCodec = parameters.get("compression").orElse(orcCompressionConf)
    val codecName = requestedCodec.getOrElse(sqlConf.orcCompressionCodec).toLowerCase(Locale.ROOT)

    // Look the short name up, listing the known codecs in the error if it is unknown.
    shortOrcCompressionCodecNames.getOrElse(codecName, {
      val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT))
      throw new IllegalArgumentException(s"Codec [$codecName] " +
        s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.")
    })
  }
}

object OrcOptions {
  // The ORC compression short names (keys are lower-case), mapped to the ORC codec names.
  private val shortOrcCompressionCodecNames = Map(
    "none" -> "NONE",
    "uncompressed" -> "NONE",
    "snappy" -> "SNAPPY",
    "zlib" -> "ZLIB",
    "lzo" -> "LZO")

  /**
   * Resolves a compression short name to the corresponding ORC codec name.
   *
   * The lookup ignores case: `name` is lower-cased with `Locale.ROOT` before it is looked up,
   * because the keys of the short-name map are all lower-case.
   *
   * @param name a compression short name such as "snappy" or "ZLIB"
   * @return the ORC codec name, e.g. "SNAPPY"
   * @throws NoSuchElementException if `name` is not a known short name
   */
  def getORCCompressionCodecName(name: String): String =
    shortOrcCompressionCodecNames(name.toLowerCase(Locale.ROOT))
}
Example 95
Source File: DataSourceV2Utils.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.v2

import java.util.regex.Pattern

import org.apache.spark.internal.Logging
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport}

private[sql] object DataSourceV2Utils extends Logging {

  /**
   * Collects the session configs that belong to the given data source.
   *
   * When `ds` implements `SessionConfigSupport`, every entry of `conf.getAllConfs` whose key
   * has the form `spark.datasource.<keyPrefix>.<rest>` is returned under the key `<rest>`.
   * The prefix is matched literally, so regex metacharacters in it (for example `.`) cannot
   * match unrelated keys. For any other data source the result is empty.
   *
   * @param ds   the data source to extract configs for
   * @param conf the session conf to read from
   * @return the matching configs, keyed by the part of the key after the prefix
   * @throws IllegalArgumentException if the data source reports a null key prefix
   */
  def extractSessionConfigs(ds: DataSourceV2, conf: SQLConf): Map[String, String] = ds match {
    case cs: SessionConfigSupport =>
      val keyPrefix = cs.keyPrefix()
      require(keyPrefix != null, "The data source config key prefix can't be null.")

      // Quote the prefix so it is treated as literal text rather than as a regex fragment.
      val pattern =
        Pattern.compile(s"^spark\\.datasource\\.${Pattern.quote(keyPrefix)}\\.(.+)")

      conf.getAllConfs.flatMap { case (key, value) =>
        val m = pattern.matcher(key)
        // The pattern always has exactly one capturing group, so a match guarantees group(1).
        if (m.matches()) {
          Seq((m.group(1), value))
        } else {
          Seq.empty
        }
      }

    case _ => Map.empty
  }
}
Example 96
Source File: SQLHadoopMapReduceCommitProtocol.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter

import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
import org.apache.spark.sql.internal.SQLConf


class SQLHadoopMapReduceCommitProtocol(
    jobId: String,
    path: String,
    dynamicPartitionOverwrite: Boolean = false)
  extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite)
    with Serializable with Logging {

  /**
   * Chooses the output committer for a task attempt.
   *
   * The committer produced by the parent protocol is used unless the Hadoop configuration
   * names a class under `SQLConf.OUTPUT_COMMITTER_CLASS`; in that case an instance of that
   * class is created reflectively and returned instead.
   */
  override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
    // The parent is always asked first, even when a user-defined class ends up being used.
    val defaultCommitter = super.setupCommitter(context)

    val userClass = context.getConfiguration
      .getClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter])

    // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat comes
    // with its own committer. The class set in SQLConf.OUTPUT_COMMITTER_CLASS takes precedence
    // over it. A data source that needs a different committer has to set it in its
    // prepareForWrite method.
    val committer: OutputCommitter =
      if (userClass == null) {
        defaultCommitter
      } else {
        logInfo(s"Using user defined output committer class ${userClass.getCanonicalName}")
        if (classOf[FileOutputCommitter].isAssignableFrom(userClass)) {
          // FileOutputCommitter subclasses are built through the (Path, TaskAttemptContext)
          // constructor.
          userClass
            .getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
            .newInstance(new Path(path), context)
        } else {
          // Any other OutputCommitter is built through its no-argument constructor.
          userClass.getDeclaredConstructor().newInstance()
        }
      }

    logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}")
    committer
  }
}
Example 97
Source File: ParquetOptions.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import java.util.Locale

import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.parquet.hadoop.metadata.CompressionCodecName

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf


  val mergeSchema: Boolean = parameters
    .get(MERGE_SCHEMA)
    .map(_.toBoolean)
    .getOrElse(sqlConf.isParquetSchemaMergingEnabled)
}


object ParquetOptions {
  /** Name of the option that toggles Parquet schema merging. */
  val MERGE_SCHEMA = "mergeSchema"

  /**
   * Short codec names accepted for Parquet, mapped to the Parquet codec enum values.
   * Both "none" and "uncompressed" select `UNCOMPRESSED`.
   */
  private val shortParquetCompressionCodecNames = Map(
    ("none", CompressionCodecName.UNCOMPRESSED),
    ("uncompressed", CompressionCodecName.UNCOMPRESSED),
    ("snappy", CompressionCodecName.SNAPPY),
    ("gzip", CompressionCodecName.GZIP),
    ("lzo", CompressionCodecName.LZO))

  /**
   * Returns the enum constant name of the Parquet codec registered under `name`.
   *
   * The lookup is an exact-match map access, so an unregistered name raises
   * `NoSuchElementException`.
   */
  def getParquetCompressionCodecName(name: String): String = {
    val codec = shortParquetCompressionCodecNames(name)
    codec.name()
  }
}
Example 98
Source File: SaveIntoDataSourceCommand.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.CreatableRelationProvider


/**
 * Command that hands the result of `query` to a [[CreatableRelationProvider]] to be saved.
 *
 * @param query      logical plan producing the rows to save
 * @param dataSource provider whose `createRelation` is invoked
 * @param options    options passed through to the provider
 * @param mode       save mode passed through to the provider
 */
case class SaveIntoDataSourceCommand(
    query: LogicalPlan,
    dataSource: CreatableRelationProvider,
    options: Map[String, String],
    mode: SaveMode) extends RunnableCommand {

  override protected def innerChildren: Seq[QueryPlan[_]] = query :: Nil

  /** Builds a Dataset from `query`, passes it to the provider and returns no rows. */
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val sqlContext = sparkSession.sqlContext
    val data = Dataset.ofRows(sparkSession, query)
    dataSource.createRelation(sqlContext, mode, options, data)

    Seq.empty[Row]
  }

  /** One-line description in which the options have been passed through `redactOptions`. */
  override def simpleString: String = {
    val safeOptions = SQLConf.get.redactOptions(options)
    s"SaveIntoDataSourceCommand $dataSource, $safeOptions, $mode"
  }
}
Example 99
Source File: Exchange.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType


/**
 * Rule that replaces an [[Exchange]] producing the same result as an earlier one in the plan
 * with a `ReusedExchangeExec` pointing at the earlier exchange.
 */
case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {

  def apply(plan: SparkPlan): SparkPlan = {
    if (conf.exchangeReuseEnabled) {
      // Exchanges seen so far, bucketed by schema so that `sameResult` is only evaluated
      // against candidates with an identical schema rather than against every exchange.
      val seenBySchema = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
      plan.transformUp {
        case exchange: Exchange =>
          // Exchanges with the same result usually share the schema (same column names) too.
          val bucket = seenBySchema.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
          bucket.find(seen => exchange.sameResult(seen)) match {
            case Some(earlier) =>
              // The output of this exchange is kept because the plans above it need it to
              // resolve their attributes.
              ReusedExchangeExec(exchange.output, earlier)
            case None =>
              bucket += exchange
              exchange
          }
      }
    } else {
      plan
    }
  }
}
Example 100
Source File: FileStreamSinkLog.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.net.URI

import org.apache.hadoop.fs.{FileStatus, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf


/**
 * Compactible metadata log of [[SinkFileStatus]] entries whose cleanup delay, expired-log
 * deletion flag and compaction interval are read from the session configuration.
 */
class FileStreamSinkLog(
    metadataLogVersion: Int,
    sparkSession: SparkSession,
    path: String)
  extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) {

  private implicit val formats = Serialization.formats(NoTypeHints)

  protected override val fileCleanupDelayMs =
    sparkSession.sessionState.conf.fileSinkLogCleanupDelay

  protected override val isDeletingExpiredLog =
    sparkSession.sessionState.conf.fileSinkLogDeletion

  protected override val defaultCompactInterval =
    sparkSession.sessionState.conf.fileSinkLogCompactInterval

  // A non-positive compaction interval is rejected at construction time.
  require(defaultCompactInterval > 0,
    s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $defaultCompactInterval) " +
      "to a positive value.")

  /**
   * Removes every entry whose path is the target of a delete entry in `logs`, which includes
   * the delete entries themselves. When no delete entry is present, `logs` is returned as is.
   */
  override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = {
    val removedPaths = logs.collect {
      case entry if entry.action == FileStreamSinkLog.DELETE_ACTION => entry.path
    }.toSet
    if (removedPaths.nonEmpty) {
      logs.filterNot(entry => removedPaths.contains(entry.path))
    } else {
      logs
    }
  }
}

/** Constants shared by users of the file stream sink log. */
object FileStreamSinkLog {
  // Version number of the log format (presumably passed as `metadataLogVersion` - TODO confirm).
  val VERSION = 1
  // Action value identifying delete entries; `compactLogs` drops every path carrying it.
  val DELETE_ACTION = "delete"
  // Action value for add entries (not referenced in the visible part of this file).
  val ADD_ACTION = "add"
}
Example 101
Source File: ConfigBehaviorSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.commons.math3.stat.inference.ChiSquareTest

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext


/** Tests that SQL configuration values change the observable behaviour of query execution. */
class ConfigBehaviorSuite extends QueryTest with SharedSQLContext {

  import testImplicits._

  test("SPARK-22160 spark.sql.execution.rangeExchange.sampleSizePerPartition") {
    // In this test, we run a sort and compute the histogram for partition size post shuffle.
    // With a high sample count, the partition size should be more evenly distributed, and has a
    // low chi-sq test value.
    // Also the whole code path for range partitioning as implemented should be deterministic
    // (it uses the partition id as the seed), so this test shouldn't be flaky.

    val numPartitions = 4

    /**
     * Sorts 10000 consecutive ids and returns the chi-square statistic of the resulting
     * per-partition row counts against a uniform distribution over `numPartitions`.
     */
    def computeChiSquareTest(): Double = {
      val n = 10000
      // Trigger a sort
      val data = spark.range(0, n, 1, 1).sort('id)
        .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect()

      // Compute histogram for the number of records per partition post sort
      val dist = data.groupBy(_._1).map(_._2.length.toLong).toArray
      // Compare against the configured partition count rather than a literal, so the check
      // stays in sync with the expected-distribution array built below.
      assert(dist.length == numPartitions)

      new ChiSquareTest().chiSquare(
        Array.fill(numPartitions) { n.toDouble / numPartitions },
        dist)
    }

    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) {
      // The default chi-sq value should be low
      assert(computeChiSquareTest() < 100)

      withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") {
        // If we only sample one point, the range boundaries will be pretty bad and the
        // chi-sq value would be very high.
        assert(computeChiSquareTest() > 300)
      }
    }
  }

}
Example 102
Source File: ParquetFileFormatSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLContext {

  test("read parquet footers in parallel") {
    // Writes two Parquet directories and one JSON directory, then reads the footers of all
    // files found in them. Only two footers are expected back.
    def testReadFooters(ignoreCorruptFiles: Boolean): Unit = {
      withTempDir { dir =>
        val hadoopConf = sparkContext.hadoopConfiguration
        val fs = FileSystem.get(hadoopConf)
        val root = dir.getCanonicalPath

        val firstDir = new Path(root, "first")
        val secondDir = new Path(root, "second")
        val thirdDir = new Path(root, "third")

        spark.range(1).toDF("a").coalesce(1).write.parquet(firstDir.toString)
        spark.range(1, 2).toDF("a").coalesce(1).write.parquet(secondDir.toString)
        // The third directory deliberately holds JSON, i.e. files without a Parquet footer.
        spark.range(2, 3).toDF("a").coalesce(1).write.json(thirdDir.toString)

        val fileStatuses =
          Seq(firstDir, secondDir, thirdDir).flatMap(d => fs.listStatus(d).toSeq)

        val footers = ParquetFileFormat.readParquetFootersInParallel(
          hadoopConf, fileStatuses, ignoreCorruptFiles)

        assert(footers.size == 2)
      }
    }

    testReadFooters(ignoreCorruptFiles = true)
    val failure = intercept[java.io.IOException] {
      testReadFooters(ignoreCorruptFiles = false)
    }
    assert(failure.getMessage.contains("Could not read footer for file"))
  }
}
Example 103
Source File: DataSourceScanExecRedactionSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext


/** Tests that file paths and configured patterns are redacted from plan string output. */
class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext {

  /** Session conf with a redaction regex that matches `file:/` followed by word characters. */
  override protected def sparkConf: SparkConf = {
    val base = super.sparkConf
    base.set("spark.redaction.string.regex", "file:/[\\w_]+")
  }

  test("treeString is redacted") {
    withTempDir { dir =>
      val basePath = dir.getCanonicalPath
      val partitionDir = new Path(basePath, "foo=1")
      spark.range(0, 10).toDF("a").write.parquet(partitionDir.toString)
      val df = spark.read.parquet(basePath)
      val qe = df.queryExecution

      val rootPath = qe.sparkPlan
        .collectFirst { case scan: FileSourceScanExec => scan }
        .get.relation.location.rootPaths.head
      assert(rootPath.toString.contains(dir.toURI.getPath.stripSuffix("/")))

      // The four plan renderings under test. Each one is produced on demand, so both rounds
      // of checks below regenerate the strings in the same order.
      val renderings: Seq[() => String] = Seq(
        () => qe.sparkPlan.treeString(verbose = true),
        () => qe.executedPlan.treeString(verbose = true),
        () => qe.toString,
        () => qe.simpleString)

      // The directory name must not appear in any rendering ...
      renderings.foreach { render => assert(!render().contains(rootPath.getName)) }

      // ... and the replacement text must appear in every one of them.
      val replacement = "*********"
      renderings.foreach { render => assert(render().contains(replacement)) }
    }
  }

  /** True when `msg` occurs in any of the three string forms of the query execution. */
  private def isIncluded(queryExecution: QueryExecution, msg: String): Boolean = {
    val forms: Seq[() => String] = Seq(
      () => queryExecution.toString,
      () => queryExecution.simpleString,
      () => queryExecution.stringWithStats)
    forms.exists(form => form().contains(msg))
  }

  test("explain is redacted using SQLConf") {
    withTempDir { dir =>
      val basePath = dir.getCanonicalPath
      val partitionDir = new Path(basePath, "foo=1")
      spark.range(0, 10).toDF("a").write.parquet(partitionDir.toString)
      val qe = spark.read.parquet(basePath).queryExecution
      val replacement = "*********"

      // Respect SparkConf and replace file:/
      assert(isIncluded(qe, replacement))

      assert(isIncluded(qe, "FileScan"))
      assert(!isIncluded(qe, "file:/"))

      withSQLConf(SQLConf.SQL_STRING_REDACTION_PATTERN.key -> "(?i)FileScan") {
        // Respect SQLConf and replace FileScan
        assert(isIncluded(qe, replacement))

        assert(!isIncluded(qe, "FileScan"))
        assert(isIncluded(qe, "file:/"))
      }
    }
  }

}
Example 104
Source File: DataSourceV2UtilsSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources.v2

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.internal.SQLConf

class DataSourceV2UtilsSuite extends SparkFunSuite {

  private val keyPrefix = new DataSourceV2WithSessionConfig().keyPrefix

  test("method withSessionConfig() should propagate session configs correctly") {
    // Only configs whose key starts with "spark.datasource.${keyPrefix}." and continues with
    // a non-empty name are expected to be extracted.
    val conf = new SQLConf
    Seq(
      s"spark.datasource.$keyPrefix.foo.bar" -> "false",
      s"spark.datasource.$keyPrefix.whateverConfigName" -> "123",
      s"spark.sql.$keyPrefix.config.name" -> "false",
      "spark.datasource.another.config.name" -> "123",
      s"spark.datasource.$keyPrefix." -> "123"
    ).foreach { case (key, value) => conf.setConfString(key, value) }

    val cs = classOf[DataSourceV2WithSessionConfig].newInstance()
    val extracted = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf)
    val names = extracted.keySet

    assert(extracted.size == 2)
    assert(!names.exists(_.startsWith("spark.datasource")))
    assert(!names.exists(_.startsWith("not.exist.prefix")))
    assert(names.contains("foo.bar"))
    assert(names.contains("whateverConfigName"))
  }
}

/**
 * Test data source that adds [[SessionConfigSupport]] to `SimpleDataSourceV2`, so that session
 * configs under its key prefix can be extracted for it.
 */
class DataSourceV2WithSessionConfig extends SimpleDataSourceV2 with SessionConfigSupport {

  // Fixed prefix; DataSourceV2UtilsSuite builds its `spark.datasource.<prefix>.*` keys from it.
  override def keyPrefix: String = "userDefinedDataSource"
}
Example 105
Source File: DataSourceTest.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/** Base class for data source tests that register one test case per SQL string. */
private[sql] abstract class DataSourceTest extends QueryTest {

  /**
   * Registers a test, named after `sqlString`, that runs the statement and compares its
   * result with `expectedAnswer`.
   *
   * @param sqlString      SQL text to execute; also used as the test name
   * @param expectedAnswer rows the query is expected to return
   * @param enableRegex    value set for `SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME` while the
   *                       query runs
   */
  protected def sqlTest(
      sqlString: String,
      expectedAnswer: Seq[Row],
      enableRegex: Boolean = false): Unit = {
    test(sqlString) {
      val regexSetting = SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> enableRegex.toString
      withSQLConf(regexSetting) {
        checkAnswer(spark.sql(sqlString), expectedAnswer)
      }
    }
  }

}

/** Relation provider that builds a [[SimpleDDLScan]] from the "from", "TO" and "Table" options. */
class DDLScanSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // The keys are looked up exactly as spelled here; a missing key fails the map lookup and
    // a non-numeric bound fails `toInt`.
    val lower = parameters("from").toInt
    val upper = parameters("TO").toInt
    val tableName = parameters("Table")
    SimpleDDLScan(lower, upper, tableName)(sqlContext.sparkSession)
  }
}

/**
 * Test relation with a fixed schema covering many column types; the scan yields one row per
 * integer in `from to to`.
 */
case class SimpleDDLScan(
    from: Int,
    to: Int,
    table: String)(@transient val sparkSession: SparkSession)
  extends BaseRelation with TableScan {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  /** Thirteen non-nullable columns followed by map, array and struct columns. */
  override def schema: StructType = {
    val requiredColumns = Seq(
      StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"),
      StructField("stringType", StringType, nullable = false),
      StructField("dateType", DateType, nullable = false),
      StructField("timestampType", TimestampType, nullable = false),
      StructField("doubleType", DoubleType, nullable = false),
      StructField("bigintType", LongType, nullable = false),
      StructField("tinyintType", ByteType, nullable = false),
      StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
      StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
      StructField("binaryType", BinaryType, nullable = false),
      StructField("booleanType", BooleanType, nullable = false),
      StructField("smallIntType", ShortType, nullable = false),
      StructField("floatType", FloatType, nullable = false))

    val nestedStruct =
      StructType(Seq(StructField("f1", StringType), StructField("f2", IntegerType)))
    val complexColumns = Seq(
      StructField("mapType", MapType(StringType, StringType)),
      StructField("arrayType", ArrayType(StringType)),
      StructField("structType", nestedStruct))

    StructType(requiredColumns ++ complexColumns)
  }

  override def needConversion: Boolean = false

  override def buildScan(): RDD[Row] = {
    val ids = sparkSession.sparkContext.parallelize(from to to)
    val internalRows = ids.map { id =>
      InternalRow(UTF8String.fromString(s"people$id"), id * 2)
    }
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    internalRows.asInstanceOf[RDD[Row]]
  }
}
Example 106
Source File: SharedSparkSession.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.test

import scala.concurrent.duration._

import org.scalatest.{BeforeAndAfterEach, Suite}
import org.scalatest.concurrent.Eventually

import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.internal.SQLConf


  /**
   * Tears down the shared session after the suite has run.
   *
   * The nested try/finally blocks make each cleanup step run even if an earlier one throws:
   * the parent's `afterAll` runs first, then the catalog reset, then the session is stopped,
   * and finally the active and default sessions are cleared.
   */
  protected override def afterAll(): Unit = {
    try {
      super.afterAll()
    } finally {
      try {
        // Nothing to stop when the session is null.
        if (_spark != null) {
          try {
            _spark.sessionState.catalog.reset()
          } finally {
            // Stop the session and drop the reference even when the catalog reset failed.
            _spark.stop()
            _spark = null
          }
        }
      } finally {
        SparkSession.clearActiveSession()
        SparkSession.clearDefaultSession()
      }
    }
  }

  /** Runs the parent hook, then clears the open streams recorded by `DebugFilesystem`. */
  protected override def beforeEach(): Unit = {
    super.beforeEach()
    DebugFilesystem.clearOpenStreams()
  }

  /**
   * Runs the parent hook, clears the cache manager and then asserts that `DebugFilesystem`
   * reports no open streams, retrying for up to 10 seconds.
   */
  protected override def afterEach(): Unit = {
    super.afterEach()
    // Clear all persistent datasets after each test
    spark.sharedState.cacheManager.clearCache()
    // files can be closed from other threads, so wait a bit
    // normally this doesn't take more than 1s
    // The assertion is retried every 2 seconds until it passes or 10 seconds have elapsed.
    eventually(timeout(10.seconds), interval(2.seconds)) {
      DebugFilesystem.assertNoOpenStreams()
    }
  }
} 
Example 107
Source File: TestSQLContext.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.test

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf}


  /**
   * Configuration overrides for tests; `TestSQLSessionStateBuilder.overrideConfs` in this file
   * returns this map.
   */
  val overrideConfs: Map[String, String] =
    Map(
      // Fewer shuffle partitions to speed up testing.
      SQLConf.SHUFFLE_PARTITIONS.key -> "5")
}

/** Session state builder for tests that takes its conf overrides from `TestSQLContext`. */
private[sql] class TestSQLSessionStateBuilder(
    session: SparkSession,
    state: Option[SessionState])
  extends SessionStateBuilder(session, state) with WithTestConf {

  override def overrideConfs: Map[String, String] = {
    TestSQLContext.overrideConfs
  }

  /** Factory that produces another builder of this same class. */
  override def newBuilder: NewBuilder = { (newSession, parentState) =>
    new TestSQLSessionStateBuilder(newSession, parentState)
  }
}
Example 108
Source File: StarryJoinLocalStrategy.scala    From starry   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, StarryHashJoinExec, StarryNestedLoopJoinExec}
import org.apache.spark.sql.internal.SQLConf


  /**
   * True when the plan's estimated size is non-negative and does not exceed the value of
   * "spark.sql.maxLocalMemoryJoin" (10485760 when the key is unset).
   */
  private def canRunInLocalMemory(plan: LogicalPlan) = {
    val estimatedSize = plan.stats.sizeInBytes
    // The limit is only looked up once the size has passed the non-negative check.
    def limit = conf.getConfString("spark.sql.maxLocalMemoryJoin", "10485760").toLong
    estimatedSize >= 0 && estimatedSize <= limit
  }

  /** Join types for which the right side may be used as the build side. */
  private def canBuildRight(joinType: JoinType): Boolean = joinType match {
    case _: InnerLike => true
    case LeftOuter | LeftSemi | LeftAnti => true
    case _: ExistenceJoin => true
    case _ => false
  }

  /** Join types for which the left side may be used as the build side. */
  private def canBuildLeft(joinType: JoinType): Boolean = joinType match {
    case _: InnerLike => true
    case RightOuter => true
    case _ => false
  }


  /**
   * Picks the build side of a join.
   *
   * A side is eligible when the join type allows building on it and its plan passes
   * `canRunInLocalMemory`. If exactly one side is eligible, that side is chosen. If both or
   * neither are eligible, the side with the smaller estimated size is chosen, with ties
   * going to the right side.
   */
  def decideBuildSide(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) = {
    val leftEligible = canBuildLeft(joinType) && canRunInLocalMemory(left)
    val rightEligible = canBuildRight(joinType) && canRunInLocalMemory(right)

    // Only evaluated when eligibility does not single out one side.
    def smallerSide =
      if (right.stats.sizeInBytes > left.stats.sizeInBytes) BuildLeft else BuildRight

    (leftEligible, rightEligible) match {
      case (false, true) => BuildRight
      case (true, false) => BuildLeft
      case _ => smallerSide
    }
  }

  /**
   * Plans join operators for logical joins and returns `Nil` for every other plan.
   *
   * Joins with extractable equi-join keys become a `StarryHashJoinExec`; every other join
   * becomes a `StarryNestedLoopJoinExec`. In both cases the build side comes from
   * `decideBuildSide`.
   */
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
      val buildSide = decideBuildSide(joinType, left, right)
      Seq(StarryHashJoinExec(
        leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))

    // --- SortMergeJoin ------------------------------------------------------------

    // NOTE(review): the unguarded equi-join case above already matches every plan this case
    // could match, so this branch is never selected as written.
    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
      if RowOrdering.isOrderable(leftKeys) =>
      joins.SortMergeJoinExec(
        leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil

    // --- Without joining keys ------------------------------------------------------------

    // Any remaining join is planned as a nested loop join.
    case j @ logical.Join(left, right, joinType, condition) =>
      val buildSide = decideBuildSide(joinType, left, right)
      StarryNestedLoopJoinExec(
        planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
    // Pick CartesianProduct for InnerJoin
    // NOTE(review): the preceding case matches every logical.Join, so this branch is never
    // selected as written.
    case logical.Join(left, right, _: InnerLike, condition) =>
      joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
    case _ => Nil
  }
} 
Example 109
Source File: HiveTestTrait.scala    From cloud-integration   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import java.io.File

import com.cloudera.spark.cloud.ObjectStoreConfigurations
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits}
import org.apache.spark.sql.hive.test.TestHiveContext
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils


/**
 * Test mix-in that creates a [[HiveInstanceForTests]] and its Spark session before the suite
 * runs and releases both afterwards.
 */
trait HiveTestTrait extends SparkFunSuite with BeforeAndAfterAll {
//  override protected val enableAutoThreadAudit = false
  protected var hiveContext: HiveInstanceForTests = _
  protected var spark: SparkSession = _


  /** Creates the Hive test instance and takes the Spark session from it. */
  protected override def beforeAll(): Unit = {
    super.beforeAll()
    val instance = new HiveInstanceForTests()
    hiveContext = instance
    spark = instance.sparkSession
  }

  /**
   * Clears the active session, resets the Hive context and closes the Spark session; the
   * parent's `afterAll` runs last, even if one of those steps throws.
   */
  protected override def afterAll(): Unit = {
    try {
      SparkSession.clearActiveSession()

      Option(hiveContext).foreach { context =>
        context.reset()
        hiveContext = null
      }
      Option(spark).foreach { session =>
        session.close()
        spark = null
      }
    } finally {
      super.afterAll()
    }
  }

}

/**
 * `TestHiveContext` running on a freshly created `SparkContext`.
 *
 * The master URL is read from the system property "spark.sql.test.master" and falls back to
 * "local[1]"; the application name is "TestSQLContext". The Spark configuration consists of
 * `ObjectStoreConfigurations.RW_TEST_OPTIONS` plus a warehouse directory obtained from
 * `TestSetup.makeWarehouseDir()`.
 */
class HiveInstanceForTests
  extends TestHiveContext(
    new SparkContext(
      System.getProperty("spark.sql.test.master", "local[1]"),
      "TestSQLContext",
      new SparkConf()
        .setAll(ObjectStoreConfigurations.RW_TEST_OPTIONS)
        .set("spark.sql.warehouse.dir",
          TestSetup.makeWarehouseDir().toURI.getPath)
    )
  ) {

}




/** Helpers for setting up the test environment. */
object TestSetup {

  /**
   * Reserves a temporary directory with the name prefix "warehouse", deletes it again and
   * returns its path.
   *
   * @return the `File` of the deleted temporary directory
   */
  def makeWarehouseDir(): File = {
    val reserved = Utils.createTempDir(namePrefix = "warehouse")
    // Only the path is kept; the directory itself is removed before returning.
    reserved.delete()
    reserved
  }
}
Example 110
Source File: SparkSQLOperationsMenager.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver.server

import java.util.{Map => JMap}
import java.util.concurrent.ConcurrentHashMap

import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
import org.apache.hive.service.cli.session.HiveSession
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SQLContext, SequilaSession}
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperationSeq}
import org.apache.spark.sql.internal.SQLConf


/**
 * Operation manager that creates `SparkExecuteStatementOperationSeq` operations bound to a
 * [[SequilaSession]].
 */
private[thriftserver] class SparkSQLOperationManagerSeq(ss: SequilaSession)
  extends OperationManager with Logging {

  // Obtained reflectively from the superclass field of the same name.
  val handleToOperation = ReflectionUtils
    .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")

  val sessionToActivePool = new ConcurrentHashMap[SessionHandle, String]()
  val sessionToContexts = new ConcurrentHashMap[SessionHandle, SQLContext]()

  /**
   * Creates and registers a statement operation for `parentSession`.
   *
   * The session's overridden configurations and Hive variables are copied into the SQL conf
   * of the session's context first. The operation runs in the background only when `async`
   * is set and `HiveUtils.HIVE_THRIFT_SERVER_ASYNC` is enabled in that conf.
   *
   * @throws IllegalArgumentException if no SQL context is registered for the session handle
   */
  override def newExecuteStatementOperation(
      parentSession: HiveSession,
      statement: String,
      confOverlay: JMap[String, String],
      async: Boolean): ExecuteStatementOperation = synchronized {
    val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
    require(
      sqlContext != null,
      s"Session handle: ${parentSession.getSessionHandle} has not been" +
        s" initialized or had already closed.")

    val sessionConf = sqlContext.sessionState.conf
    val hiveState = parentSession.getSessionState
    setConfMap(sessionConf, hiveState.getOverriddenConfigurations)
    setConfMap(sessionConf, hiveState.getHiveVariables)

    val runInBackground = async && sessionConf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC)
    val operation = new SparkExecuteStatementOperationSeq(
      parentSession, statement, confOverlay, runInBackground)(ss, sessionToActivePool)
    handleToOperation.put(operation.getHandle, operation)
    logDebug(s"Created Operation for $statement with session=$parentSession, " +
      s"runInBackground=$runInBackground")
    operation
  }

  /** Copies every entry of `confMap` into `conf` as a string-valued setting. */
  def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = {
    val entries = confMap.entrySet().iterator()
    while (entries.hasNext) {
      val entry = entries.next()
      conf.setConfString(entry.getKey, entry.getValue)
    }
  }
}
Example 111
Source File: GenomicInterval.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package  org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Range, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.biodatageeks.sequila.utils.Columns

/**
 * Leaf logical plan node describing a single interval on a contig.
 *
 * @param contig contig name
 * @param start  interval start
 * @param end    interval end
 * @param output attributes exposed by this node
 */
case class GenomicInterval(
    contig: String,
    start: Int,
    end: Int,
    output: Seq[Attribute])
  extends LeafNode with MultiInstanceRelation with Serializable {

  /** Copy of this node whose output attributes are fresh instances. */
  override def newInstance(): GenomicInterval = {
    val freshOutput = output.map(attribute => attribute.newInstance())
    copy(output = freshOutput)
  }

  /** Size estimate covering the two integer columns only. */
  def computeStats(conf: SQLConf): Statistics = {
    //FIXME: Add contigName size
    val twoIntColumns = IntegerType.defaultSize * 2
    Statistics(sizeInBytes = twoIntColumns)
  }

  override def simpleString: String = s"GenomicInterval ($contig, $start, $end)"

}

object GenomicInterval {
  /**
   * Builds a [[GenomicInterval]] whose output consists of three non-nullable attributes
   * named after `Columns.CONTIG`, `Columns.START` and `Columns.END`.
   */
  def apply(contig: String, start: Int, end: Int): GenomicInterval = {
    val contigField = StructField(s"${Columns.CONTIG}", StringType, nullable = false)
    val startField = StructField(s"${Columns.START}", IntegerType, nullable = false)
    val endField = StructField(s"${Columns.END}", IntegerType, nullable = false)

    val attributes = StructType(Seq(contigField, startField, endField)).toAttributes
    new GenomicInterval(contig, start, end, attributes)
  }
}
Example 112
Source File: SeQuiLaAnalyzer.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.ResolveTableValuedFunctionsSeq
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

import scala.util.Random


/**
 * Analyzer that overrides the rule batches.
 *
 * The "Resolution" batch starts with `ResolveTableValuedFunctionsSeq`, and a "SeQuiLa" batch
 * holding the rules from `sequilaOptmazationRules` runs once after "Post-Hoc Resolution".
 * `conf.optimizerMaxIterations` is passed to the parent as its third constructor argument.
 */
class SeQuiLaAnalyzer(catalog: SessionCatalog, conf: SQLConf) extends Analyzer(catalog, conf, conf.optimizerMaxIterations){
  //override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq(ResolveTableValuedFunctionsSeq)


  //  override lazy val batches: Seq[Batch] = Seq(
  //    Batch("Custeom", fixedPoint, ResolveTableValuedFunctionsSeq),
  //    Batch("Hints", fixedPoint, new ResolveHints.ResolveBroadcastHints(conf),
  //      ResolveHints.RemoveAllHints))


  // Rules for the "SeQuiLa" batch; empty by default. `batches` is a lazy val that reads this
  // field when it is first evaluated, so rules assigned after that point are not picked up.
  var sequilaOptmazationRules: Seq[Rule[LogicalPlan]] = Nil

  // Batches are listed in the order in which they are declared to run.
  override lazy val batches: Seq[Batch] = Seq(
    Batch("Hints", fixedPoint,
      new ResolveHints.ResolveBroadcastHints(conf),
      ResolveHints.RemoveAllHints),
    Batch("Simple Sanity Check", Once,
      LookupFunctions),
    Batch("Substitution", fixedPoint,
      CTESubstitution,
      WindowsSubstitution,
      EliminateUnions,
      new SubstituteUnresolvedOrdinals(conf)),
    Batch("Resolution", fixedPoint,
      // Placed ahead of all other resolution rules in this batch.
      ResolveTableValuedFunctionsSeq ::
      ResolveRelations ::
        ResolveReferences ::
        ResolveCreateNamedStruct ::
        ResolveDeserializer ::
        ResolveNewInstance ::
        ResolveUpCast ::
        ResolveGroupingAnalytics ::
        ResolvePivot ::
        ResolveOrdinalInOrderByAndGroupBy ::
        ResolveAggAliasInGroupBy ::
        ResolveMissingReferences ::
        ExtractGenerator ::
        ResolveGenerate ::
        ResolveFunctions ::
        ResolveAliases ::
        ResolveSubquery ::
        ResolveSubqueryColumnAliases ::
        ResolveWindowOrder ::
        ResolveWindowFrame ::
        ResolveNaturalAndUsingJoin ::

        ExtractWindowExpressions ::
        GlobalAggregates ::
        ResolveAggregateFunctions ::
        TimeWindowing ::
        ResolveInlineTables(conf) ::
        ResolveTimeZone(conf) ::
        TypeCoercion.typeCoercionRules(conf) ++
          extendedResolutionRules : _*),
    Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
    Batch("SeQuiLa", Once,sequilaOptmazationRules: _*), //SeQuilaOptimization rules
    Batch("View", Once,
      AliasViewChild(conf)),
    Batch("Nondeterministic", Once,
      PullOutNondeterministic),
    Batch("UDF", Once,
      HandleNullInputsForUDF),
    Batch("FixNullability", Once,
      FixNullability),
    Batch("Subquery", Once,
      UpdateOuterReferences),
    Batch("Cleanup", fixedPoint,
      CleanupAliases)
  )



}
Example 113
Source File: IntervalTreeJoinOptim.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.rangejoins.IntervalTree

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, _}
import org.apache.spark.sql.internal.SQLConf

/**
 * Physical operator that joins `left` and `right` on overlapping integer
 * intervals by delegating to `IntervalTreeJoinOptimImpl.overlapJoin`.
 *
 * @param left            physical plan whose rows form the left side of the output
 * @param right           physical plan whose rows form the right side of the output
 * @param condition       four expressions: elements 0 and 1 are evaluated against
 *                        `left` to produce the two Int interval bounds, elements 2
 *                        and 3 are evaluated against `right`
 * @param context         session whose `sparkContext` is handed to the join
 * @param leftLogicalPlan logical plan of the left side; only its statistics are read
 * @param righLogicalPlan logical plan of the right side; only its statistics are read
 */
@DeveloperApi
case class IntervalTreeJoinOptim(left: SparkPlan,
                                 right: SparkPlan,
                                 condition: Seq[Expression],
                                 context: SparkSession,leftLogicalPlan: LogicalPlan, righLogicalPlan: LogicalPlan) extends BinaryExecNode {
  def output: Seq[Attribute] = left.output ++ right.output

  lazy val (buildPlan, streamedPlan) = (left, right)

  lazy val (buildKeys, streamedKeys) = (List(condition(0), condition(1)),
    List(condition(2), condition(3)))

  @transient lazy val buildKeyGenerator = new InterpretedProjection(buildKeys, buildPlan.output)
  @transient lazy val streamKeyGenerator = new InterpretedProjection(streamedKeys,
    streamedPlan.output)

  /**
   * Size estimate used to decide which side is passed first to `overlapJoin`:
   * the plan's `sizeInBytes` statistic when it is positive, otherwise the row count.
   *
   * The estimate stays a `BigInt` so that very large statistics are compared
   * exactly instead of being truncated by a conversion to `Long`.
   *
   * NOTE(review): the fallback compares a row count against a byte size, i.e.
   * mixed units; that mirrors the pre-existing behaviour and is kept as is.
   *
   * @param plan     logical plan whose statistics are consulted
   * @param rowCount row count, evaluated only when the statistic is not positive
   */
  private def estimatedSize(plan: LogicalPlan, rowCount: => Long): BigInt = {
    val bytes = plan.stats.sizeInBytes
    if (bytes > 0) bytes else BigInt(rowCount)
  }

  /**
   * Executes both children, wraps every row in an `IntervalWithRow` keyed by
   * its two Int interval bounds, runs the overlap join with the side that has
   * the smaller size estimate as the first RDD argument, and concatenates each
   * matched pair into a single row laid out as (left columns, right columns).
   */
  protected override def doExecute(): RDD[InternalRow] = {
    val leftRows = left.execute()
    val leftIntervals = leftRows.map { row =>
      val bounds = buildKeyGenerator(row)
      new IntervalWithRow[Int](bounds.getInt(0), bounds.getInt(1), row)
    }

    val rightRows = right.execute()
    val rightIntervals = rightRows.map { row =>
      val bounds = streamKeyGenerator(row)
      new IntervalWithRow[Int](bounds.getInt(0), bounds.getInt(1), row)
    }

    // Counting triggers a Spark job, so each side is counted at most once and
    // only if the count is actually needed.
    lazy val leftCount = leftRows.count()
    lazy val rightCount = rightRows.count()

    val leftSize = estimatedSize(leftLogicalPlan, leftCount)
    val rightSize = estimatedSize(righLogicalPlan, rightCount)

    if (leftSize <= rightSize) {
      val matched =
        IntervalTreeJoinOptimImpl.overlapJoin(context.sparkContext, leftIntervals, rightIntervals, leftCount)
      matched.mapPartitions { pairs =>
        // One joiner per partition; pairs arrive as (left row, right row).
        val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
        pairs.map(pair => joiner.join(pair._1.asInstanceOf[UnsafeRow], pair._2.asInstanceOf[UnsafeRow]))
      }
    } else {
      val matched =
        IntervalTreeJoinOptimImpl.overlapJoin(context.sparkContext, rightIntervals, leftIntervals, rightCount)
      matched.mapPartitions { pairs =>
        // The RDD arguments were swapped above, so the pair elements are
        // swapped back here to keep the output as (left columns, right columns).
        val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
        pairs.map(pair => joiner.join(pair._2.asInstanceOf[UnsafeRow], pair._1.asInstanceOf[UnsafeRow]))
      }
    }
  }
}
Example 114
Source File: SharedSparkSession.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.internal.SQLConf
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSpec, Suite}
import org.scalatest.concurrent.Eventually

/**
 * ScalaTest mix-in for `FunSpec` suites that drops cached data after every test.
 *
 * NOTE(review): `spark` is referenced below but is not declared in the part of
 * the trait visible here; it is presumably the shared `SparkSession` — confirm
 * against the full source.
 */
trait SharedSparkSession
    extends FunSpec
    with BeforeAndAfterEach
    with BeforeAndAfterAll
    with Eventually { self: Suite =>

  /** Runs the inherited per-test teardown, then clears the shared cache manager. */
  protected override def afterEach(): Unit = {
    super.afterEach()
    // Datasets cached by one test must not remain cached for the next one.
    val cacheManager = spark.sharedState.cacheManager
    cacheManager.clearCache()
  }
}
Example 115
Source File: SharedSparkSession.scala    From gimel   with Apache License 2.0 5 votes vote down vote up
package com.paypal.gimel.common.utilities.spark

import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.internal.SQLConf
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Suite}
import org.scalatest.concurrent.Eventually

/**
 * ScalaTest mix-in that clears cached data after every test and offers a
 * helper for generating mock rows.
 *
 * NOTE(review): `spark` is referenced below but is not declared in the part of
 * the trait visible here; it is presumably the shared `SparkSession` — confirm
 * against the full source.
 */
trait SharedSparkSession
    extends BeforeAndAfterEach
    with BeforeAndAfterAll
    with Eventually { self: Suite =>

  /** Extra configuration entries; empty unless a suite overrides it. */
  protected val additionalConfig: Map[String, String] = Map.empty

  /** Runs the inherited per-test teardown, then clears the shared cache manager. */
  protected override def afterEach(): Unit = {
    super.afterEach()
    // Datasets cached by one test must not remain cached for the next one.
    val cacheManager = spark.sharedState.cacheManager
    cacheManager.clearCache()
  }

  /**
   * Builds a DataFrame of mock records by parsing generated JSON strings.
   *
   * @param numberOfRows number of records to generate; ids run from 1 to this value
   * @return the DataFrame obtained by reading the generated JSON records
   */
  def mockDataInDataFrame(numberOfRows: Int): DataFrame = {
    // One JSON document per id; every value is emitted as a JSON string.
    val jsonRecords: Seq[String] =
      for (n <- 1 to numberOfRows)
        yield s"""{"id": "$n","name": "MAC-$n", "address": "MAC-${n + 1}", "age": "${n + 1}", "company": "MAC-$n", "designation": "MAC-$n", "salary": "${n * 10000}" }"""
    val jsonRdd: RDD[String] = spark.sparkContext.parallelize(jsonRecords)
    spark.read.json(jsonRdd)
  }
}
Example 116
Source File: DataTypeUtil.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

/** Structural comparison of Spark SQL data types that ignores nullability flags. */
object DataTypeUtil {

  /**
   * Tells whether two data types have the same structure, ignoring nullability.
   *
   * Struct field names are compared exactly when
   * `SQLConf.get.caseSensitiveAnalysis` is true and case-insensitively otherwise.
   *
   * @param left  first type
   * @param right second type
   * @return true when both types match under the rules above
   */
  def sameType(left: DataType, right: DataType): Boolean =
    matches(left, right, caseSensitive = SQLConf.get.caseSensitiveAnalysis)

  /**
   * Recursive worker behind [[sameType]].
   *
   * Arrays are compared by element type, maps by key and value type, structs
   * field by field in positional order (same field count, matching name and
   * matching type); any other pair falls back to `==`.
   */
  private def matches(left: DataType, right: DataType, caseSensitive: Boolean): Boolean =
    (left, right) match {
      case (ArrayType(leftElement, _), ArrayType(rightElement, _)) =>
        matches(leftElement, rightElement, caseSensitive)

      case (MapType(leftKey, leftValue, _), MapType(rightKey, rightValue, _)) =>
        matches(leftKey, rightKey, caseSensitive) &&
          matches(leftValue, rightValue, caseSensitive)

      case (StructType(leftFields), StructType(rightFields)) =>
        leftFields.length == rightFields.length &&
          leftFields.zip(rightFields).forall { case (l, r) =>
            val sameName =
              if (caseSensitive) l.name == r.name
              else l.name.equalsIgnoreCase(r.name)
            sameName && matches(l.dataType, r.dataType, caseSensitive)
          }

      case _ => left == right
    }
}
Example 117
Source File: CSVOutputFormatter.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.transformers.pipe

import java.io.InputStream

import scala.collection.JavaConverters._

import com.univocity.parsers.csv.CsvParser
import org.apache.commons.io.IOUtils
import org.apache.spark.sql.execution.datasources.csv.{CSVDataSourceUtils, CSVUtils, UnivocityParserUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import io.projectglow.SparkShim.{CSVOptions, UnivocityParser}

/**
 * Output formatter that parses a stream of CSV text into a schema followed by rows.
 *
 * @param parsedOptions CSV options used for filtering, header handling and parsing
 */
class CSVOutputFormatter(parsedOptions: CSVOptions) extends OutputFormatter {

  /**
   * Derives an all-string, all-nullable schema from the first parsed record.
   * Column names come from `CSVDataSourceUtils.makeSafeHeader`, which receives
   * the session's case-sensitivity setting and the CSV options.
   */
  private def getSchema(record: Array[String]): StructType = {
    val caseSensitive = SQLConf.get.caseSensitiveAnalysis
    val columnNames = CSVDataSourceUtils.makeSafeHeader(record, caseSensitive, parsedOptions)
    val fields = columnNames.map(name => StructField(name, StringType, nullable = true))
    StructType(fields)
  }

  /**
   * Reads the stream as UTF-8 lines, drops comment and empty lines, and yields
   * the derived schema followed by a copy of every parsed row.
   *
   * @param stream raw CSV text
   * @return an empty iterator when no usable line exists; otherwise the schema
   *         followed by the rows (minus the first row when `headerFlag` is set)
   */
  override def makeIterator(stream: InputStream): Iterator[Any] = {
    val rawLines = IOUtils.lineIterator(stream, "UTF-8").asScala
    val usableLines = CSVUtils.filterCommentAndEmpty(rawLines, parsedOptions)

    if (usableLines.hasNext) {
      // The first usable line defines the schema whether or not it is a header.
      val headLine = usableLines.next()
      val headRecord = new CsvParser(parsedOptions.asParserSettings).parseLine(headLine)
      val schema = getSchema(headRecord)
      val rowParser = new UnivocityParser(schema, schema, parsedOptions)

      // The consumed line is put back in front so that it is parsed as data too.
      val allRows = UnivocityParserUtils.parseIterator(
        Iterator(headLine) ++ usableLines,
        rowParser,
        schema
      )
      val dataRows = if (parsedOptions.headerFlag) allRows.drop(1) else allRows

      Iterator(schema) ++ dataRows.map(_.copy)
    } else {
      Iterator.empty
    }
  }
}

/** Factory registered under the name "csv" that creates [[CSVOutputFormatter]] instances. */
class CSVOutputFormatterFactory extends OutputFormatterFactory {
  override def name: String = "csv"

  /**
   * Creates a formatter from user-supplied options.
   *
   * @param options raw option map; combined with the session's CSV column
   *                pruning flag and session-local time zone into `CSVOptions`
   */
  override def makeOutputFormatter(
      options: Map[String, String]
  ): OutputFormatter = {
    val sessionConf = SQLConf.get
    val columnPruning = sessionConf.csvColumnPruning
    val timeZone = sessionConf.sessionLocalTimeZone
    new CSVOutputFormatter(new CSVOptions(options, columnPruning, timeZone))
  }
}
Example 118
Source File: implicits.scala    From spark-states   with Apache License 2.0 5 votes vote down vote up
package ru.chermenin.spark.sql.execution.streaming.state

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.SparkSession.Builder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.DataStreamWriter
import ru.chermenin.spark.sql.execution.streaming.state.RocksDbStateStoreProvider._

import scala.collection.mutable



/** Extension methods for wiring the RocksDB state store into Spark sessions and stream writers. */
object implicits extends Serializable {

  /** Adds RocksDB state-store configuration to a `SparkSession.Builder`. */
  implicit class SessionImplicits(sparkSessionBuilder: Builder) {

    /** Sets the state store provider class to [[RocksDbStateStoreProvider]]. */
    def useRocksDBStateStore(): Builder = {
      val providerClass = classOf[RocksDbStateStoreProvider].getCanonicalName
      sparkSessionBuilder.config(SQLConf.STATE_STORE_PROVIDER_CLASS.key, providerClass)
    }

  }

  /** Adds state-expiry configuration to a `DataStreamWriter`. */
  implicit class WriterImplicits[T](dsw: DataStreamWriter[T]) {

    /**
     * Registers a state expiry for the query and pins its name and checkpoint location.
     *
     * @param runtimeConfig      runtime config that receives the expiry setting and is
     *                           consulted for a default checkpoint location
     * @param queryName          query name; when empty or null, the writer's
     *                           "queryName" option or `UNNAMED_QUERY` is used
     * @param expirySecs         expiry in seconds; any negative value is stored as -1
     * @param checkpointLocation base checkpoint location; when empty or null, the
     *                           writer's "checkpointLocation" option or the session
     *                           setting is used
     * @return the writer with the query name and checkpoint location applied
     * @throws IllegalStateException if no checkpoint location can be determined
     */
    def stateTimeout(runtimeConfig: RuntimeConfig,
                     queryName: String = "",
                     expirySecs: Int = DEFAULT_STATE_EXPIRY_SECS.toInt,
                     checkpointLocation: String = ""): DataStreamWriter[T] = {

      val writerOptions = getExtraOptions

      val resolvedName =
        if (queryName == null || queryName == "") writerOptions.getOrElse("queryName", UNNAMED_QUERY)
        else queryName

      // Evaluated only when neither the argument nor the writer option supplies a location.
      def sessionCheckpointLocation: String =
        runtimeConfig.getOption(SQLConf.CHECKPOINT_LOCATION.key).getOrElse(
          throw new IllegalStateException(
            "Checkpoint Location must be specified for State Expiry either " +
              """through option("checkpointLocation", ...) or """ +
              s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)"""))

      val baseLocation =
        if (checkpointLocation == null || checkpointLocation == "") {
          writerOptions.getOrElse("checkpointLocation", sessionCheckpointLocation)
        } else {
          checkpointLocation
        }

      // The query name becomes a child path of the base location.
      val location = new Path(baseLocation, resolvedName).toUri.toString

      val storedExpiry = if (expirySecs >= 0) expirySecs else -1
      runtimeConfig.set(s"$STATE_EXPIRY_SECS.$resolvedName", storedExpiry)

      dsw.queryName(resolvedName).option("checkpointLocation", location)
    }

    /** Reads the writer's `extraOptions` field via reflection. */
    private def getExtraOptions: mutable.HashMap[String, String] = {
      val optionsField = classOf[DataStreamWriter[T]].getDeclaredField("extraOptions")
      optionsField.setAccessible(true)
      optionsField.get(dsw).asInstanceOf[mutable.HashMap[String, String]]
    }
  }

}
Example 119
Source File: SimbaOptimizer.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.{And, Expression, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkOptimizer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.simba.plans.SpatialJoin


/**
 * Optimizer that runs every batch of `SparkOptimizer` and then one additional
 * batch applying [[PushPredicateThroughSpatialJoin]].
 */
class SimbaOptimizer(catalog: SessionCatalog,
                     conf: SQLConf,
                     experimentalMethods: ExperimentalMethods)
 extends SparkOptimizer(catalog, conf, experimentalMethods) {
  override def batches: Seq[Batch] = {
    // Appended last so it runs after all inherited batches.
    val spatialJoinPushDown =
      Batch("SpatialJoinPushDown", FixedPoint(100), PushPredicateThroughSpatialJoin)
    super.batches :+ spatialJoinPushDown
  }
}

/**
 * Logical rule that moves predicates referencing only one side of a
 * `SpatialJoin` into a `Filter` directly above that side.
 */
object PushPredicateThroughSpatialJoin extends Rule[LogicalPlan] with PredicateHelper {

  /**
   * Splits conjuncts into (left-only, right-only, remaining). A conjunct whose
   * references fit inside both sides' outputs is classified as left-only,
   * because the left test is applied first.
   */
  private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
    val (leftOnly, notLeft) =
      condition.partition(expr => expr.references.subsetOf(left.outputSet))
    val (rightOnly, shared) =
      notLeft.partition(expr => expr.references.subsetOf(right.outputSet))

    (leftOnly, rightOnly, shared)
  }

  /** Wraps `child` in a `Filter` over the AND of `conjuncts`, or returns it unchanged when there are none. */
  private def withFilter(conjuncts: Seq[Expression], child: LogicalPlan): LogicalPlan =
    conjuncts.reduceLeftOption(And) match {
      case Some(predicate) => Filter(predicate, child)
      case None => child
    }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // A Filter directly above a spatial join: one-sided conjuncts go below the
    // join, the rest are merged into the join condition.
    case Filter(filterCondition, SpatialJoin(left, right, joinType, joinCondition)) =>
      val (leftOnly, rightOnly, shared) =
        split(splitConjunctivePredicates(filterCondition), left, right)

      SpatialJoin(
        withFilter(leftOnly, left),
        withFilter(rightOnly, right),
        joinType,
        (shared ++ joinCondition).reduceLeftOption(And))

    // A bare spatial join: one-sided conjuncts of its own condition go below
    // the join, the rest stay as the join condition.
    case SpatialJoin(left, right, joinType, joinCondition) =>
      val conjuncts = joinCondition.map(splitConjunctivePredicates).getOrElse(Nil)
      val (leftOnly, rightOnly, shared) = split(conjuncts, left, right)

      SpatialJoin(
        withFilter(leftOnly, left),
        withFilter(rightOnly, right),
        joinType,
        shared.reduceLeftOption(And))
  }
}
Example 120
Source File: FrontendService.scala    From spark-sql-server   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.server.service

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.ChannelInitializer
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.logging.{LoggingHandler, LogLevel}

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.server.SQLServerConf._

/**
 * Base class for Netty-backed front ends: `doInit` reads the port and worker
 * thread count from the configuration and creates the event loop groups,
 * `doStart` binds the server and blocks until its channel is closed.
 */
private[service] abstract class FrontendService extends CompositeService {

  // All four are assigned in doInit.
  var port: Int = _
  var workerThreads: Int = _
  var bossGroup: NioEventLoopGroup = _
  var workerGroup: NioEventLoopGroup = _

  /** Channel initializer installed as the child handler of the server bootstrap. */
  def messageHandler: ChannelInitializer[SocketChannel]

  /** Reads port and worker-thread settings and creates both event loop groups. */
  override def doInit(conf: SQLConf): Unit = {
    port = conf.sqlServerPort
    workerThreads = conf.sqlServerWorkerThreads
    // The boss group gets a single thread; the worker group is sized from the configuration.
    bossGroup = new NioEventLoopGroup(1)
    workerGroup = new NioEventLoopGroup(workerThreads)
  }

  /**
   * Binds to `port` and blocks until the server channel is closed. Both event
   * loop groups are shut down on the way out, including when binding fails.
   */
  override def doStart(): Unit = {
    try {
      val bootstrap = new ServerBootstrap()
      // SO_KEEPALIVE is intentionally left unset:
      // bootstrap.option(ChannelOption.SO_KEEPALIVE, true)
      bootstrap.group(bossGroup, workerGroup)
      bootstrap.channel(classOf[NioServerSocketChannel])
      bootstrap.handler(new LoggingHandler(LogLevel.INFO))
      bootstrap.childHandler(messageHandler)

      // Binds and starts to accept incoming connections
      val bound = bootstrap.bind(port).sync()

      logInfo(s"Start running the SQL server (port=$port, workerThreads=$workerThreads)")

      // Blocked until the server socket is closed
      bound.channel().closeFuture().sync()
    } finally {
      bossGroup.shutdownGracefully()
      workerGroup.shutdownGracefully()
    }
  }
}
Example 121
Source File: Services.scala    From spark-sql-server   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.server.service

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.internal.SQLConf

/** Lifecycle contract for a server component: initialise, start, stop. */
abstract class Service extends Logging {

  /** Initialises the service from the given SQL configuration. */
  def init(conf: SQLConf): Unit
  /** Starts the service. */
  def start(): Unit
  /** Stops the service. */
  def stop(): Unit
}

/**
 * A service made of child services plus its own `doInit` / `doStart` /
 * `doStop` hooks, which do nothing unless overridden.
 */
abstract class CompositeService extends Service {

  // Children in registration order.
  protected val services = new mutable.ArrayBuffer[Service]()

  /** Registers a child service. */
  protected def addService(service: Service): Unit = {
    services += service
  }

  /** Initialises every child in registration order, then this service. */
  final override def init(conf: SQLConf): Unit = {
    for (child <- services) child.init(conf)
    doInit(conf)
  }

  /** Starts every child in registration order, then this service. */
  final override def start(): Unit = {
    for (child <- services) child.start()
    doStart()
  }

  /** Stops this service first, then every child in registration order. */
  final override def stop(): Unit = {
    doStop()
    for (child <- services) child.stop()
  }

  def doInit(conf: SQLConf): Unit = {}
  def doStart(): Unit = {}
  def doStop(): Unit = {}
}
Example 122
Source File: SQLServerEnv.scala    From spark-sql-server   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.server

import scala.util.control.NonFatal

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, SQLContext}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.server.ui.SQLServerTab
import org.apache.spark.util.Utils

/**
 * Process-wide holder for the Spark and SQL objects used by the SQL server.
 * Everything is created lazily on first access, unless a context has been
 * injected beforehand through [[withSQLContext]].
 */
object SQLServerEnv extends Logging {

  // For test use
  private var _sqlContext: Option[SQLContext] = None

  /**
   * Injects an existing `SQLContext` and immediately forces creation of the
   * listener and the UI tab.
   *
   * @param sqlContext context to use; must not be null
   */
  @DeveloperApi
  def withSQLContext(sqlContext: SQLContext): Unit = {
    require(sqlContext != null)
    _sqlContext = Option(sqlContext)
    // Referencing the lazy vals is what triggers their initialisation.
    sqlServListener
    uiTab
  }

  /** Copies every entry of `sparkConf` into `sqlConf`. */
  private def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
    sparkConf.getAll.foreach { case (k, v) =>
      sqlConf.setConfString(k, v)
    }
  }

  /** Configuration of the injected context if there is one, otherwise a freshly loaded one. */
  lazy val sparkConf: SparkConf = _sqlContext.map(_.sparkContext.conf).getOrElse {
    val sparkConf = new SparkConf(loadDefaults = true)

    // If user doesn't specify the appName, we want to get [SparkSQL::localHostName]
    // instead of the default appName [SQLServer].
    val maybeAppName = sparkConf
      .getOption("spark.app.name")
      .filterNot(_ == classOf[SQLServer].getName)
    sparkConf
      .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}"))
      .set("spark.sql.crossJoin.enabled", "true")
  }

  /** SQL configuration of the injected context, otherwise a new one seeded from [[sparkConf]]. */
  lazy val sqlConf: SQLConf = _sqlContext.map(_.conf).getOrElse {
    val newSqlConf = new SQLConf()
    mergeSparkConf(newSqlConf, sparkConf)
    newSqlConf
  }

  lazy val sqlContext: SQLContext = _sqlContext.getOrElse(newSQLContext(sparkConf))
  lazy val sparkContext: SparkContext = sqlContext.sparkContext
  lazy val sqlServListener: Option[SQLServerListener] = Some(newSQLServerListener(sqlContext))
  // flatMap instead of `.get`, so no unchecked unwrap of the listener Option.
  lazy val uiTab: Option[SQLServerTab] = sqlServListener.flatMap(newUiTab(sqlContext, _))

  /**
   * Creates a Hive-enabled `SQLContext`. When `spark.sql.server.extensions.builder`
   * names a Scala object, that object is loaded and installed as the session
   * extensions builder; any non-fatal failure while doing so is logged and a
   * plain context is built instead.
   *
   * @param conf Spark configuration for the new session
   */
  private[sql] def newSQLContext(conf: SparkConf): SQLContext = {
    def buildSQLContext(f: SparkSessionExtensions => Unit = _ => {}): SQLContext = {
      SparkSession.builder.config(conf).withExtensions(f).enableHiveSupport()
        .getOrCreate().sqlContext
    }
    val builderClassName = conf.get("spark.sql.server.extensions.builder", "")
    if (builderClassName.nonEmpty) {
      // Tries to install user-defined extensions
      try {
        // A Scala object compiles to a class named `<name>$` holding the
        // singleton in its static MODULE$ field.
        val objName = builderClassName + (if (!builderClassName.endsWith("$")) "$" else "")
        val clazz = Utils.classForName(objName)
        val builder = clazz.getDeclaredField("MODULE$").get(null)
          .asInstanceOf[SparkSessionExtensions => Unit]
        val sqlContext = buildSQLContext(builder)
        logInfo(s"Successfully installed extensions from $builderClassName")
        sqlContext
      } catch {
        case NonFatal(e) =>
          logWarning(s"Failed to install extensions from $builderClassName: ${e.getMessage}")
          buildSQLContext()
      }
    } else {
      buildSQLContext()
    }
  }

  /** Creates a listener over the context's configuration and registers it with the SparkContext. */
  def newSQLServerListener(sqlContext: SQLContext): SQLServerListener = {
    val listener = new SQLServerListener(sqlContext.conf)
    sqlContext.sparkContext.addSparkListener(listener)
    listener
  }

  /**
   * Creates the SQL server UI tab when `spark.ui.enabled` is true (the default).
   *
   * The tab is built from the SparkContext of the `sqlContext` argument, the
   * same context whose configuration is checked, instead of reaching for the
   * global `SQLServerEnv.sqlContext`, which could differ from the argument and
   * would be lazily created as a side effect.
   *
   * @param sqlContext context whose SparkContext hosts the UI
   * @param listener   listener backing the tab
   * @return the tab, or None when the UI is disabled
   */
  def newUiTab(sqlContext: SQLContext, listener: SQLServerListener): Option[SQLServerTab] = {
    val sc = sqlContext.sparkContext
    if (sc.conf.getBoolean("spark.ui.enabled", true)) {
      Some(SQLServerTab(sc, listener))
    } else {
      None
    }
  }
}
Example 123
Source File: SQLServerUtils.scala    From spark-sql-server   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.server.util

import java.io.File
import java.lang.reflect.Field
import java.util.StringTokenizer

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.server.SQLServerConf._
import org.apache.spark.sql.server.SQLServerEnv
import org.apache.spark.util.Utils

/** Miscellaneous helpers for the SQL server: environment checks, Kerberos settings, classpath lookup. */
object SQLServerUtils {

  /** True when `spark.sql.server.testing` is set to exactly "true" in the server's SparkConf. */
  def isTesting: Boolean =
    SQLServerEnv.sparkConf.getOption("spark.sql.server.testing").contains("true")

  /**
   * True when `spark.master` is present in the configuration and starts with "yarn".
   *
   * The underlying settings map returns null for a missing key, so the lookup
   * is wrapped in `Option`: an unset master yields false instead of a
   * NullPointerException.
   */
  def isRunningOnYarn(conf: SQLConf): Boolean =
    Option(conf.settings.get("spark.master")).exists(_.startsWith("yarn"))

  /**
   * True when both `spark.yarn.keytab` and `spark.yarn.principal` are configured.
   *
   * @throws IllegalArgumentException if impersonation is enabled outside multi-context mode
   */
  def isKerberosEnabled(conf: SQLConf): Boolean = {
    require(!conf.sqlServerImpersonationEnabled || conf.sqlServerExecutionMode == "multi-context",
      "Impersonation can be enabled in multi-context mode only")
    conf.contains("spark.yarn.keytab") && conf.contains("spark.yarn.principal")
  }

  /** Value of `spark.yarn.keytab`; fails when it is not provided. */
  def kerberosKeytab(conf: SQLConf): String =
    requiredKerberosSetting(conf, "spark.yarn.keytab")

  /** Value of `spark.yarn.principal`; fails when it is not provided. */
  def kerberosPrincipal(conf: SQLConf): String =
    requiredKerberosSetting(conf, "spark.yarn.principal")

  /** Looks up a Kerberos-related setting and rejects a null value. */
  private def requiredKerberosSetting(conf: SQLConf, key: String): String = {
    val value = conf.getConfString(key)
    require(value != null, s"Kerberos requires `$key` to be provided.")
    value
  }

  /**
   * Searches the JVM classpath for a file with the given name.
   *
   * For a classpath entry that is a regular file (e.g. a jar), the directory
   * containing it is searched; for any other entry the entry itself is treated
   * as the directory.
   *
   * @param fileName name of the file to look for
   * @return the first match in classpath order, or None
   */
  def findFileOnClassPath(fileName: String): Option[File] = {
    val classpath = System.getProperty("java.class.path")
    val pathSeparator = System.getProperty("path.separator")
    val tokenizer = new StringTokenizer(classpath, pathSeparator)
    Iterator.continually(tokenizer)
      .takeWhile(_.hasMoreTokens)
      .map(_.nextToken())
      .map { pathElement =>
        val directoryOrJar = new File(pathElement)
        val absoluteDirectoryOrJar = directoryOrJar.getAbsoluteFile
        if (absoluteDirectoryOrJar.isFile) {
          new File(absoluteDirectoryOrJar.getParent, fileName)
        } else {
          new File(directoryOrJar, fileName)
        }
      }
      .find(_.exists())
  }

  // https://blog.sebastian-daschner.com/entries/changing_env_java
  /** Adds `key -> value` to the JVM's view of the process environment via reflection. */
  def injectEnvVar(key: String, value: String): Unit = {
    val clazz = Utils.classForName("java.lang.ProcessEnvironment")
    injectIntoUnmodifiableMap(key, value, clazz)
  }

  /** Returns the named declared field with access checks disabled. */
  private def getDeclaredField(clazz: Class[_], fieldName: String): Field = {
    val field = clazz.getDeclaredField(fieldName)
    field.setAccessible(true)
    field
  }

  /**
   * Reads the static `theUnmodifiableEnvironment` field of `clazz`, extracts the
   * map wrapped by that `Collections$UnmodifiableMap` (its `m` field) and puts
   * the entry into it.
   */
  private def injectIntoUnmodifiableMap(key: String, value: String, clazz: Class[_]): Unit = {
    val unmodifiableEnvField = getDeclaredField(clazz, "theUnmodifiableEnvironment")
    val unmodifiableEnv = unmodifiableEnvField.get(null)
    val unmodifiableMapClazz = Utils.classForName("java.util.Collections$UnmodifiableMap")
    val field = getDeclaredField(unmodifiableMapClazz, "m")
    field.get(unmodifiableEnv).asInstanceOf[java.util.Map[String, String]].put(key, value)
  }
}
Example 124
Source File: PgWireProtocolSuite.scala    From spark-sql-server   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.server.service.postgresql.protocol.v3

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.SQLException

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

/** Tests for the byte layout of messages produced by `PgWireProtocol`. */
class PgWireProtocolSuite extends SparkFunSuite {

  val conf = new SQLConf()

  test("DataRow") {
    val v3Protocol = new PgWireProtocol(65536)
    val row = new GenericInternalRow(2)
    row.update(0, 8)
    row.update(1, UTF8String.fromString("abcdefghij"))
    val schema = StructType.fromDDL("a INT, b STRING")
    val rowConverters = PgRowConverters(conf, schema, Seq(true, false))
    val data = v3Protocol.DataRow(row, rowConverters)
    val bytes = ByteBuffer.wrap(data)
    // Header fields are read back in wire order: 1-byte tag, 4-byte length,
    // 2-byte column count, then a 4-byte length in front of each value.
    assert(bytes.get() === 'D'.toByte)
    assert(bytes.getInt === 28)
    assert(bytes.getShort === 2)
    assert(bytes.getInt === 4)
    assert(bytes.getInt === 8)
    assert(bytes.getInt === 10)
    // 1 + 4 + 2 + 4 + 4 + 4 = 19 bytes precede the string payload. The slice is
    // bounded by the payload length so the check does not depend on `slice`
    // clamping an out-of-range upper bound.
    val expectedPayload = "abcdefghij".getBytes(StandardCharsets.UTF_8)
    val payloadOffset = 19
    assert(data.slice(payloadOffset, payloadOffset + expectedPayload.length) === expectedPayload)
  }

  test("Fails when message buffer overflowed") {
    // A 4-byte buffer cannot hold an 11-character string value.
    val v3Protocol = new PgWireProtocol(4)
    val row = new GenericInternalRow(1)
    row.update(0, UTF8String.fromString("abcdefghijk"))
    val schema = StructType.fromDDL("a STRING")
    val rowConverters = PgRowConverters(conf, schema, Seq(false))
    val errMsg = intercept[SQLException] {
      v3Protocol.DataRow(row, rowConverters)
    }.getMessage
    assert(errMsg.contains(
      "Cannot generate a V3 protocol message because buffer is not enough for the message. " +
        "To avoid this exception, you might set higher value at " +
        "'spark.sql.server.messageBufferSizeInBytes'")
    )
  }
}