/*
 * Copyright 2016 by Simba Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.simba.execution.join

import org.apache.spark.sql.simba.execution.SimbaPlan
import org.apache.spark.sql.simba.index.RTree
import org.apache.spark.sql.simba.partitioner.{MapDPartition, STRPartition}
import org.apache.spark.sql.simba.spatial.Point
import org.apache.spark.sql.simba.util.{NumberUtil, ShapeUtils}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Literal}
import org.apache.spark.sql.execution.SparkPlan

import scala.collection.mutable

/**
 * Created by dong on 1/20/16.
 * Distance join based on SJMR (Spatial Join with MapReduce).
 *
 * Both inputs are STR-partitioned. Every (left, right) partition pair whose
 * MBRs lie within distance `r` of each other is assigned a fresh "copy"
 * partition; the rows of both partitions are replicated there and joined
 * locally with an R-tree.
 */
case class DJSpark(left_key: Expression, right_key: Expression, l: Literal,
                   left: SparkPlan, right: SparkPlan) extends SimbaPlan {
  override def output: Seq[Attribute] = left.output ++ right.output

  final val num_partitions = simbaSessionState.simbaConf.joinPartitions
  final val sample_rate = simbaSessionState.simbaConf.sampleRate
  final val max_entries_per_node = simbaSessionState.simbaConf.maxEntriesPerNode
  final val transfer_threshold = simbaSessionState.simbaConf.transferThreshold
  final val r = NumberUtil.literalToDouble(l)

  override protected def doExecute(): RDD[InternalRow] = {
    // Extract the join key of every row as a Point, keeping the row alongside.
    val left_rdd = left.execute().map(row =>
      (ShapeUtils.getShape(left_key, left.output, row).asInstanceOf[Point], row)
    )

    val right_rdd = right.execute().map(row =>
      (ShapeUtils.getShape(right_key, right.output, row).asInstanceOf[Point], row)
    )

    val dimension = right_rdd.first()._1.coord.length

    // STR-partition both sides; each call also returns the MBR of every partition.
    val (left_partitioned, left_mbr_bound) =
      STRPartition(left_rdd, dimension, num_partitions, sample_rate,
        transfer_threshold, max_entries_per_node)

    val (right_partitioned, right_mbr_bound) =
      STRPartition(right_rdd, dimension, num_partitions, sample_rate,
        transfer_threshold, max_entries_per_node)

    // Index the right partitions' MBRs so that candidate partitions can be
    // found with a circle-range query instead of a full scan.
    val right_rt = RTree(right_mbr_bound.zip(Array.fill[Int](right_mbr_bound.length)(0))
      .map(x => (x._1._1, x._1._2, x._2)), max_entries_per_node)

    // left_dup(i) / right_dup(j) record the copy partitions that left
    // partition i and right partition j must be replicated to; `tot` counts
    // the copy partitions allocated so far.
    val left_dup = new Array[Array[Int]](left_mbr_bound.length)
    val right_dup = new Array[Array[Int]](right_mbr_bound.length)

    var tot = 0
    left_mbr_bound.foreach { now =>
      val res = right_rt.circleRange(now._1, r)
      val tmp_arr = mutable.ArrayBuffer[Int]()
      res.foreach { x =>
        if (right_dup(x._2) == null) right_dup(x._2) = Array(tot)
        else right_dup(x._2) = right_dup(x._2) :+ tot
        tmp_arr += tot
        tot += 1
      }
      left_dup(now._2) = tmp_arr.toArray
    }

    val bc_left_dup = sparkContext.broadcast(left_dup)
    val bc_right_dup = sparkContext.broadcast(right_dup)

    // Tag every row with the ids of the copy partitions it belongs to.
    val left_dup_rdd = left_partitioned.mapPartitionsWithIndex { (id, iter) =>
      iter.flatMap { now =>
        val tmp_list = bc_left_dup.value(id)
        if (tmp_list != null) tmp_list.map(x => (x, now))
        else Array[(Int, (Point, InternalRow))]()
      }
    }

    val right_dup_rdd = right_partitioned.mapPartitionsWithIndex { (id, iter) =>
      iter.flatMap { now =>
        val tmp_list = bc_right_dup.value(id)
        if (tmp_list != null) tmp_list.map(x => (x, now))
        else Array[(Int, (Point, InternalRow))]()
      }
    }
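    // Worked example of the copy-id bookkeeping above (the numbers are
    // illustrative, not from the source): suppose left partition 0's MBR lies
    // within distance r of right partitions 2 and 5. The loop then allocates
    // fresh copy ids 0 and 1, so that
    //   left_dup(0)  = Array(0, 1)   -- left partition 0 goes to copies 0 and 1
    //   right_dup(2) = Array(0)      -- right partition 2 goes to copy 0
    //   right_dup(5) = Array(1)      -- right partition 5 goes to copy 1
    // After the MapDPartition shuffle below, copy partition k of each side
    // holds exactly one candidate (left, right) pair, so zipPartitions can
    // join the two sides locally without any further data movement.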
    // Shuffle both sides by copy-partition id, then drop the id: the k-th
    // partitions of the two resulting RDDs form one candidate pair.
    val left_dup_partitioned = MapDPartition(left_dup_rdd, tot).map(_._2)
    val right_dup_partitioned = MapDPartition(right_dup_rdd, tot).map(_._2)

    left_dup_partitioned.zipPartitions(right_dup_partitioned) { (leftIter, rightIter) =>
      val ans = mutable.ListBuffer[InternalRow]()
      val right_data = rightIter.toArray
      if (right_data.nonEmpty) {
        // Build a local R-tree over the right points, then probe it with each
        // left point; circleRange returns the entries within distance r.
        val right_index = RTree(right_data.map(_._1).zipWithIndex, max_entries_per_node)
        leftIter.foreach { now =>
          ans ++= right_index.circleRange(now._1, r)
            .map(x => new JoinedRow(now._2, right_data(x._2)._2))
        }
      }
      ans.iterator
    }
  }

  override def children: Seq[SparkPlan] = Seq(left, right)
}
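// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of this file's compiled code).
// DJSpark is an internal physical plan that Simba's planner selects for
// distance joins; it is not constructed by hand. A query that could be
// planned as DJSpark might look like the sketch below, assuming Simba's
// SimbaSession and its Dataset implicits; the dataset names, column names
// ("x", "y"), and config values are hypothetical.
//
//   val simba = SimbaSession.builder()
//     .master("local[4]")
//     .appName("DistanceJoinExample")
//     .config("simba.join.partitions", "32")
//     .getOrCreate()
//   import simba.implicits._
//   import simba.simbaImplicits._
//
//   // Pair each row of ds1 with the rows of ds2 whose points lie within
//   // distance 3.0 of its point.
//   ds1.distanceJoin(ds2, Array("x", "y"), Array("x", "y"), 3.0).show()
// ---------------------------------------------------------------------------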