package org.apache.spark.sql import com.collective.TestSparkContext import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType} import org.scalatest.FlatSpec import org.apache.spark.sql.functions._ import org.apache.spark.sql.ext.functions._ import scala.collection.mutable class ExtAggregatesSpec extends FlatSpec with TestSparkContext { val schema = StructType(Seq( StructField("cookie_id", StringType), StructField("site", StringType), StructField("impressions", LongType) )) val cookie1 = "cookie1" val cookie2 = "cookie2" val cookie3 = "cookie3" val impressionLog = sqlContext.createDataFrame(sc.parallelize(Seq( Row(cookie1, "google.com", 10L), Row(cookie1, "cnn.com", 14L), Row(cookie1, "google.com", 2L), Row(cookie2, "bbc.com", 20L), Row(cookie2, "auto.com", null), Row(cookie2, "auto.com", 1L), Row(cookie3, "sport.com", 100L) )), schema) "Ext Aggregates" should "collect column values as array" in { val cookies = impressionLog .select(collectArray(col("cookie_id"))) .first().getAs[mutable.WrappedArray[String]](0) assert(cookies.length == 7) assert(cookies.toSet.size == 3) } it should "collect distinct values as array" in { val distinctCookies = impressionLog.select(col("cookie_id")) .distinct() .select(collectArray(col("cookie_id"))) .first().getAs[mutable.WrappedArray[String]](0) assert(distinctCookies.length == 3) } it should "collect values after group by" in { val result = impressionLog .groupBy(col("cookie_id")) .agg(collectArray(col("site"))) val cookieSites = result.collect().map { case Row(cookie: String, sites: mutable.WrappedArray[_]) => cookie -> sites.toSeq }.toMap assert(cookieSites(cookie1).length == 3) assert(cookieSites(cookie2).length == 3) assert(cookieSites(cookie3).length == 1) } }