package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

class UseAliasesForAggregationsInGroupingsSuite extends FunSuite with MockitoSugar {

  val br1 = new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]
    override def schema: StructType = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)
    ))
  }

  val lr1 = LogicalRelation(br1)
  val nameAtt = lr1.output.find(_.name == "name").get
  val ageAtt = lr1.output.find(_.name == "age").get

  test("replace functions in group by") {
    val avgExpr = avg(ageAtt)
    val avgAlias = avgExpr as 'avgAlias
    assertResult(
      lr1.groupBy(avgAlias.toAttribute)(avgAlias)
    )(UseAliasesForFunctionsInGroupings(
      lr1.groupBy(avgExpr)(avgAlias))
    )
    assertResult(
      lr1.select(ageAtt)
    )(UseAliasesForFunctionsInGroupings(
      lr1.select(ageAtt))
      )
    intercept[RuntimeException](
      UseAliasesForFunctionsInGroupings(Aggregate(Seq(avgExpr), Seq(ageAtt), lr1))
    )
  }

}