package com.godatadriven.join

import com.godatadriven.SparkUtil
import com.godatadriven.common.Config
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.annotation.tailrec

object IterativeBroadcastJoin extends JoinStrategy {

  /**
   * Joins `broadcast` into `result` one slice at a time.
   *
   * On pass `i` only the rows of `broadcast` whose `pass` column equals `i` are
   * left-outer-joined onto `result` by `key`. A `label` already present in
   * `result` (non-null) is kept; otherwise the slice's `label` is taken.
   * Each pass's output is written to `tmp_broadcast_table_itr_<i>.parquet` via
   * `SparkUtil.dfWrite` and read back with `SparkUtil.dfRead` before the next
   * pass. This is presumably done to cut the query lineage between passes
   * (TODO confirm against SparkUtil).
   *
   * NOTE(review): assumes `key` is unique within a single pass of `broadcast`.
   * With a left outer join, duplicate keys in a slice would multiply rows of
   * `result`.
   *
   * @param spark     session used to read back each pass's output
   * @param result    output accumulated so far; must contain `key` and `label`
   * @param broadcast table being joined in; must contain `key`, `label` and `pass`
   * @param iteration index of the current pass, starting at 0
   * @return `result` after running passes `iteration` until
   *         `Config.numberOfBroadcastPasses`
   */
  @tailrec
  private def iterativeBroadcastJoin(spark: SparkSession,
                                     result: DataFrame,
                                     broadcast: DataFrame,
                                     iteration: Int = 0): DataFrame =
    if (iteration < Config.numberOfBroadcastPasses) {
      val tableName = s"tmp_broadcast_table_itr_$iteration.parquet"

      // Apply the hint to the exact slice that takes part in the join, so the
      // broadcast strategy does not depend on the caller having hinted `broadcast`.
      // Fully qualified because the `broadcast` parameter shadows the imported function.
      val slice = org.apache.spark.sql.functions.broadcast(
        broadcast.filter(col("pass") === lit(iteration))
      )

      val out = result
        .join(slice, Seq("key"), "left_outer")
        .select(
          result("key"),
          // Keep a label found in an earlier pass; otherwise take this pass's label.
          // Reference the joined slice itself rather than its unfiltered parent.
          coalesce(result("label"), slice("label")).as("label")
        )

      SparkUtil.dfWrite(out, tableName)

      iterativeBroadcastJoin(
        spark,
        SparkUtil.dfRead(spark, tableName),
        broadcast,
        iteration + 1
      )
    } else result

  /**
   * Joins `dfMedium` into `dfLarge` by `key` using repeated broadcast joins.
   *
   * The output starts as the `key` column of `dfLarge` plus a null `label`.
   * `label` is then filled in from `dfMedium` over `Config.numberOfBroadcastPasses`
   * passes. Each pass joins only the rows of `dfMedium` whose `pass` column equals
   * the pass index. Columns of `dfLarge` other than `key` are not carried over.
   *
   * @param spark    active session, used to read back intermediate pass results
   * @param dfLarge  large side of the join; only its `key` column is used
   * @param dfMedium side to broadcast; must contain `key`, `label` and `pass`
   * @return a DataFrame with columns `key` and `label`
   */
  override def join(spark: SparkSession,
                    dfLarge: DataFrame,
                    dfMedium: DataFrame): DataFrame =
    iterativeBroadcastJoin(
      spark,
      dfLarge
        .select("key")
        .withColumn("label", lit(null)),
      // `broadcast` returns a NEW hinted DataFrame; the hint only takes effect
      // if that returned value is the one that gets joined.
      broadcast(dfMedium)
    )

}