package org.biodatageeks.sequila.rangejoins.methods.genApp

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InterpretedProjection, UnsafeRow}
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.biodatageeks.sequila.rangejoins.genApp.Interval

@DeveloperApi
case class IntervalTreeJoinChromosome(left: SparkPlan,
                                      right: SparkPlan,
                                      condition: Seq[Expression],
                                      context: SparkSession)
  extends BinaryExecNode {

  override def output: Seq[Attribute] = left.output ++ right.output

  lazy val (buildPlan, streamedPlan) = (left, right)

  // condition carries (start, end) range bounds plus a chromosome column for
  // each side: indices 0, 1, 4 key the build side and 2, 3, 5 the streamed side.
  lazy val (buildKeys, streamedKeys) =
    (List(condition(0), condition(1), condition(4)),
      List(condition(2), condition(3), condition(5)))

  @transient lazy val buildKeyGenerator =
    new InterpretedProjection(buildKeys, buildPlan.output)
  @transient lazy val streamKeyGenerator =
    new InterpretedProjection(streamedKeys, streamedPlan.output)

  protected override def doExecute(): RDD[InternalRow] = {
    // Key every row of each side by (chromosome, interval).
    val v1 = left.execute()
    val v1kv = v1.map { x =>
      val v1Key = buildKeyGenerator(x)
      ((v1Key.getString(2), new Interval[Int](v1Key.getInt(0), v1Key.getInt(1))),
        x.copy())
    }
    val v2 = right.execute()
    val v2kv = v2.map { x =>
      val v2Key = streamKeyGenerator(x)
      ((v2Key.getString(2), new Interval[Int](v2Key.getInt(0), v2Key.getInt(1))),
        x.copy())
    }
    /* As we are going to collect v1 and build an interval tree on its
     * intervals, make sure that its size is the smaller one. Note that both
     * count() calls below are full Spark actions, so each input is evaluated
     * once just to pick the build side. */
    if (v1.count <= v2.count) {
      // overlapJoin returns (build row, matching stream rows); flatten the
      // groups into (left row, right row) pairs.
      val v3 = IntervalTreeJoinChromosomeImpl
        .overlapJoin(context.sparkContext, v1kv, v2kv)
        .flatMap(l => l._2.map(r => (l._1, r)))
      v3.mapPartitions { p =>
        val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
        p.map(r => joiner.join(r._1.asInstanceOf[UnsafeRow], r._2.asInstanceOf[UnsafeRow]))
      }
    } else {
      // Build the tree on the smaller right side; the pairs then come out as
      // (right row, left row), so swap the arguments when joining to keep the
      // declared output order left ++ right. The joiner must accordingly be
      // generated for (left.schema, right.schema) to match that argument order.
      val v3 = IntervalTreeJoinChromosomeImpl
        .overlapJoin(context.sparkContext, v2kv, v1kv)
        .flatMap(l => l._2.map(r => (l._1, r)))
      v3.mapPartitions { p =>
        val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
        p.map(r => joiner.join(r._2.asInstanceOf[UnsafeRow], r._1.asInstanceOf[UnsafeRow]))
      }
    }
  }
}
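
/* ---------------------------------------------------------------------------
 * Usage sketch (illustrative only, not part of the operator). A physical node
 * like IntervalTreeJoinChromosome is normally injected through the planner
 * via SparkSession.experimental.extraStrategies rather than constructed by
 * hand. The strategy class name below (IntervalTreeJoinStrategyChromosome)
 * and the table/column names are assumptions made for this example.
 *
 *   val spark = SparkSession.builder()
 *     .master("local[*]")
 *     .appName("interval-tree-join-sketch")
 *     .getOrCreate()
 *
 *   // Hypothetical strategy registration; with it in place the planner can
 *   // turn a chromosome-equality plus range-overlap join into this node.
 *   spark.experimental.extraStrategies =
 *     new IntervalTreeJoinStrategyChromosome(spark) :: Nil
 *
 *   // The six condition expressions of this node would then line up with
 *   // (r.start, r.end, t.start, t.end, r.chr, t.chr) in a query such as:
 *   spark.sql(
 *     """SELECT * FROM reads r JOIN targets t
 *       |ON r.chr = t.chr
 *       |   AND r.end >= t.start
 *       |   AND r.start <= t.end""".stripMargin).show()
 * ------------------------------------------------------------------------- */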