/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.sources.druid

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
import org.sparklinedata.druid._

/**
 * Derives a Spark-side aggregation from the Druid query held by `druidOpSchema.dqb`.
 *
 * Every Druid dimension becomes a grouping attribute and every Druid aggregation is
 * mapped to a Spark aggregate function over the corresponding output attribute
 * (sums and counts are re-aggregated with `Sum`, minima with `Min`, maxima with `Max`).
 * The resulting operator is planned by [[aggOp]] on top of a child plan.
 *
 * NOTE(review): judging by `canBeExecutedInHistorical`, this is presumably used to
 * merge partial results when a query is pushed to Druid historicals -- confirm with callers.
 *
 * @param druidOpSchema schema of the Druid operator; supplies the query builder (`dqb`)
 *                      and the name -> attribute mapping (`druidAttrMap`).
 */
class PostAggregate(val druidOpSchema : DruidOperatorSchema) {

  /** The Druid query builder taken from the operator schema. */
  val dqb = druidOpSchema.dqb

  /** Builds a Spark attribute reference carrying the name, data type and expression id
   *  of the given Druid operator attribute. */
  private def attrRef(dOpAttr : DruidOperatorAttribute) : AttributeReference =
    AttributeReference(dOpAttr.name, dOpAttr.dataType)(dOpAttr.exprId)

  /** One attribute reference per Druid dimension, looked up by the dimension's output name. */
  lazy val groupExpressions = dqb.dimensions.map { dim =>
    attrRef(druidOpSchema.druidAttrMap(dim.outputName))
  }

  /** Grouping expressions handed to the aggregate planner; same as [[groupExpressions]]. */
  def namedGroupingExpressions = groupExpressions

  /**
   * Maps a Druid aggregation to the Spark aggregate function that re-aggregates its output.
   *
   * The attribute lookup happens before the match, so an aggregation whose name is missing
   * from `druidAttrMap` fails on the lookup rather than yielding `None`.
   *
   * @param dAggSpec the Druid aggregation specification
   * @return the Spark aggregate function, or `None` when the aggregation kind is not handled
   */
  private def toSparkAgg(dAggSpec : AggregationSpec) : Option[AggregateFunction] = {
    val dOpAttr = druidOpSchema.druidAttrMap(dAggSpec.name)
    dAggSpec match {
      // Counts and sums are both combined by summing their already-computed values.
      case FunctionAggregationSpec("count" | "longSum" | "doubleSum", _, _) =>
        Some(Sum(attrRef(dOpAttr)))
      case FunctionAggregationSpec("longMin" | "doubleMin", _, _) =>
        Some(Min(attrRef(dOpAttr)))
      case FunctionAggregationSpec("longMax" | "doubleMax", _, _) =>
        Some(Max(attrRef(dOpAttr)))
      // Javascript aggregations are classified by the prefix of their second field.
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _) if aggnm.startsWith("MIN") =>
        Some(Min(attrRef(dOpAttr)))
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _) if aggnm.startsWith("MAX") =>
        Some(Max(attrRef(dOpAttr)))
      case JavascriptAggregationSpec(_, aggnm, _, _, _, _)
        if aggnm.startsWith("SUM") || aggnm.startsWith("COUNT") =>
        Some(Sum(attrRef(dOpAttr)))
      case _ => None
    }
  }

  /**
   * The Spark aggregate expressions for all Druid aggregations, each aliased to the name
   * and expression id of its Druid output attribute. The value is whatever
   * `Utils.sequence` returns for the per-aggregation options.
   *
   * NOTE(review): `Utils.sequence` is defined outside this file; it presumably yields
   * `None` as soon as one aggregation has no Spark equivalent -- confirm.
   */
  lazy val aggregatesO : Option[List[NamedExpression]] = Utils.sequence(
    dqb.aggregations.map { da =>
      val dOpAttr = druidOpSchema.druidAttrMap(da.name)
      toSparkAgg(da).map { aggFunc =>
        Alias(AggregateExpression(aggFunc, Complete, false), dOpAttr.name)(dOpAttr.exprId)
      }
    })

  /** True when the query builder allows pushing to historicals and [[aggregatesO]] is defined. */
  def canBeExecutedInHistorical : Boolean = dqb.canPushToHistorical && aggregatesO.isDefined

  /**
   * Grouping attributes followed by the aliased aggregate expressions.
   *
   * Throws `NoSuchElementException` when [[aggregatesO]] is empty; callers are expected to
   * check [[canBeExecutedInHistorical]] first. The exception type matches what a bare
   * `Option.get` would raise, but the message now states the cause.
   */
  lazy val resultExpressions = groupExpressions ++ aggregatesO.getOrElse(
    throw new NoSuchElementException(
      "PostAggregate: at least one Druid aggregation has no Spark equivalent; " +
        "check canBeExecutedInHistorical before requesting result expressions"))

  /** The distinct [[AggregateExpression]]s occurring anywhere inside [[resultExpressions]]. */
  lazy val aggregateExpressions = resultExpressions.flatMap { expr =>
    expr.collect {
      case agg: AggregateExpression => agg
    }
  }.distinct

  /**
   * Maps `(aggregate function, isDistinct)` to an attribute named after the function.
   *
   * Not read by any other member of this class; retained because it is a public member.
   */
  lazy val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
    val aggregateFunction = agg.aggregateFunction
    val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
    (aggregateFunction, agg.isDistinct) -> attribute
  }.toMap

  /**
   * [[resultExpressions]] with every aggregate expression replaced by its own
   * `resultAttribute`, i.e. by a reference to the aggregate's output.
   */
  lazy val rewrittenResultExpressions = resultExpressions.map { expr =>
    expr.transformDown {
      // An earlier variant looked the attribute up in aggregateFunctionToAttribute;
      // the expression's own result attribute is used instead.
      case aggExpr: AggregateExpression => aggExpr.resultAttribute
      case expression => expression
    }.asInstanceOf[NamedExpression]
  }

  /**
   * Plans the aggregation over `child` without a partial-aggregation step.
   *
   * @param child the plan producing the rows to aggregate
   * @return the physical plans returned by `AggUtils.planAggregateWithoutPartial`
   */
  def aggOp(child : SparkPlan) : Seq[SparkPlan] = {
    org.apache.spark.sql.execution.aggregate.AggUtils.planAggregateWithoutPartial(
      namedGroupingExpressions,
      aggregateExpressions,
      rewrittenResultExpressions,
      child)
  }

}