package com.kakao.cuesheet.convert

import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.RDD

import scala.reflect.ClassTag

/**
 * Wraps a pair RDD with inner-join operations intended to tolerate keys that carry many values.
 *
 * For every key, the left-hand values are numbered `0 until n`. Each right-hand value is then
 * replicated once per number, so the per-key cross product is produced by an ordinary join on the
 * composite key `(key, index)`. Those composite keys hash to many partitions instead of one.
 * Note that the initial `cogroup` still gathers all values of a key together; only the
 * cross-product expansion is spread out.
 *
 * @param rdd the left-hand side of every join performed by this wrapper
 * @tparam K key type
 * @tparam V value type of `rdd`
 */
class JoinableRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) {

  /**
   * Inner-joins the wrapped RDD with itself, yielding the same pairs as `rdd.join(rdd)`.
   *
   * @param numPartitions number of partitions for the intermediate shuffles;
   *                      defaults to the partition count of the wrapped RDD
   * @return one `(key, (v1, v2))` record for every ordered pair of values sharing a key
   */
  def selfJoin(numPartitions: Int = rdd.partitions.length): RDD[(K, (V, V))] = fastJoin(rdd, numPartitions)

  /**
   * Inner-joins the wrapped RDD with `other`, yielding the same pairs as `rdd.join(other)`.
   *
   * @param other         right-hand side of the join
   * @param numPartitions number of partitions for the intermediate shuffles and the result;
   *                      defaults to the partition count of the wrapped RDD
   * @tparam W value type of `other`
   * @return one `(key, (v, w))` record for every combination of a left value `v` and a right
   *         value `w` sharing that key; keys present on only one side produce nothing
   */
  def fastJoin[W](other: RDD[(K, W)], numPartitions: Int = rdd.partitions.length): RDD[(K, (V, W))] = {
    val partitioner = new HashPartitioner(numPartitions)

    // Keys missing from either side cannot contribute to an inner join, so drop them before
    // they are pushed through the two shuffles below.
    // NOTE(review): `grouped` feeds both `left` and `right` and is not persisted, so the cogroup
    // is evaluated for each of them; callers may want to persist expensive inputs.
    val grouped = rdd.cogroup(other).filter { case (_, (vs, ws)) => vs.nonEmpty && ws.nonEmpty }

    // Left side: number each value of a key 0 until vs.size.
    val left = grouped.flatMap { case (k, (vs, _)) =>
      vs.iterator.zipWithIndex.map { case (v, idx) => ((k, idx), v) }
    }.partitionBy(partitioner)

    // Right side: first spread each key's right-hand values across partitions, keyed by their
    // position in the group. Using the position rather than w.hashCode() avoids an NPE on null
    // values and keeps duplicate values of a hot key from landing in a single partition.
    // Then replicate every value once per left-hand index so it meets every left value below.
    val right = grouped.flatMap { case (k, (vs, ws)) =>
      val copies = vs.size // loop-invariant: computed once per key, not once per value
      ws.iterator.zipWithIndex.map { case (w, pos) => ((k, pos), (w, copies)) }
    }.partitionBy(partitioner).flatMap { case ((k, _), (w, copies)) =>
      // Generated lazily so a hot key's replicas are never all held in memory at once.
      Iterator.range(0, copies).map(idx => ((k, idx), w))
    }

    // `left` is already partitioned by `partitioner`; passing it explicitly lets the join reuse
    // that layout and shuffle only `right`.
    left.join(right, partitioner).map { case ((k, _), pair) => (k, pair) }
  }

}