/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution

import scala.collection.mutable.HashSet

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.{Accumulator, AccumulatorParam, Logging}

/**
 * Contains methods for debugging query execution.
 *
 * Usage:
 * {{{
 *   import org.apache.spark.sql.execution.debug._
 *   sql("SELECT key FROM src").debug()
 * }}}
 */
package object debug {

  /**
   * Augments [[SQLContext]] with debug methods.
   */
  implicit class DebugSQLContext(sqlContext: SQLContext) {
    def debug(): Unit = {
      sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false)
    }
  }

  /**
   * Augments [[DataFrame]]s with debug methods.
   */
  implicit class DebugQuery(query: DataFrame) extends Logging {
    def debug(): Unit = {
      val plan = query.queryExecution.executedPlan
      val visited = new collection.mutable.HashSet[TreeNodeRef]()
      // Wrap each distinct operator in the physical plan exactly once with a DebugNode.
      val debugPlan = plan transform {
        case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) =>
          visited += new TreeNodeRef(s)
          DebugNode(s)
      }
      logDebug(s"Results returned: ${debugPlan.execute().count()}")
      debugPlan.foreach {
        case d: DebugNode => d.dumpStats()
        case _ =>
      }
    }
  }

  private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
    def output: Seq[Attribute] = child.output

    implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
      def zero(initialValue: HashSet[String]): HashSet[String] = {
        initialValue.clear()
        initialValue
      }

      def addInPlace(v1: HashSet[String], v2: HashSet[String]): HashSet[String] = {
        v1 ++= v2
        v1
      }
    }
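    // Note: `zero` clears and reuses the initial set, while `addInPlace` merges in place, so the
    // per-partition sets of observed class names are combined by simple union when accumulator
    // updates reach the driver. For example, merging {"java.lang.Integer"} with
    // {"java.lang.Integer", "java.lang.Long"} yields {"java.lang.Integer", "java.lang.Long"}.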
    /**
     * A collection of metrics for each column of output.
     * @param elementTypes the actual runtime types for the output.  Useful when there are bugs
     *                     causing the wrong data to be projected.
     */
    case class ColumnMetrics(
        elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))

    val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)

    val numColumns: Int = child.output.size
    val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics())

    def dumpStats(): Unit = {
      logDebug(s"== ${child.simpleString} ==")
      logDebug(s"Tuples output: ${tupleCount.value}")
      child.output.zip(columnStats).foreach { case (attr, metric) =>
        val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
        logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
      }
    }

    protected override def doExecute(): RDD[InternalRow] = {
      child.execute().mapPartitions { iter =>
        new Iterator[InternalRow] {
          def hasNext: Boolean = iter.hasNext

          def next(): InternalRow = {
            val currentRow = iter.next()
            tupleCount += 1
            var i = 0
            while (i < numColumns) {
              val value = currentRow.get(i, output(i).dataType)
              if (value != null) {
                columnStats(i).elementTypes += HashSet(value.getClass.getName)
              }
              i += 1
            }
            currentRow
          }
        }
      }
    }
  }
}
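
// A minimal usage sketch (not part of this file): the SparkContext setup, application name,
// and the column name "key" below are illustrative assumptions only.
//
//   import org.apache.spark.{SparkConf, SparkContext}
//   import org.apache.spark.sql.SQLContext
//   import org.apache.spark.sql.execution.debug._
//
//   val sc = new SparkContext(new SparkConf().setAppName("debug-example").setMaster("local[*]"))
//   val sqlContext = new SQLContext(sc)
//   sqlContext.debug()   // disables eager analysis via DebugSQLContext
//   val df = sqlContext.range(0, 10).toDF("key")
//   df.debug()           // wraps each operator in a DebugNode and logs per-column type stats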