Repository: spark
Updated Branches:
  refs/heads/master 841721e03 -> 5731af5be


[SQL] Rewrite join implementation to allow streaming of one relation.

Before, we were materializing everything in memory. This also uses the
projection interface, so it will be easier to plug in code gen (it's ported
from that branch).
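
For intuition, the approach is roughly the following self-contained Scala sketch
(integer keys, tuple rows, and the HashJoinSketch name are illustrative
assumptions; the actual operator works per partition on Catalyst Rows with key
projections, see joins.scala below):

    import scala.collection.mutable.ArrayBuffer

    object HashJoinSketch {
      // Build an in-memory hash table from the (smaller) build side, then stream
      // the other side through it one row at a time, never materializing it.
      def hashJoin(
          build: Iterator[(Int, String)],
          stream: Iterator[(Int, String)]): Iterator[((Int, String), (Int, String))] = {
        // Phase 1: materialize the build side into key -> matching rows.
        val table = new java.util.HashMap[Int, ArrayBuffer[(Int, String)]]()
        while (build.hasNext) {
          val row = build.next()
          var matches = table.get(row._1)
          if (matches == null) {
            matches = new ArrayBuffer[(Int, String)]()
            table.put(row._1, matches)
          }
          matches += row
        }
        // Phase 2: lazily pair each streamed row with its match list, if any.
        stream.flatMap { row =>
          val matches = table.get(row._1)
          if (matches == null) Iterator.empty else matches.iterator.map(m => (row, m))
        }
      }

      def main(args: Array[String]): Unit = {
        val users  = Iterator((1, "alice"), (2, "bob"))            // build side
        val orders = Iterator((1, "book"), (1, "pen"), (3, "mug")) // streamed side
        hashJoin(users, orders).foreach(println)
      }
    }

Only the build side is held in memory; the streamed side is consumed lazily
through the returned iterator.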

@rxin @liancheng

Author: Michael Armbrust <mich...@databricks.com>

Closes #250 from marmbrus/hashJoin and squashes the following commits:

1ad873e [Michael Armbrust] Change hasNext logic back to the correct version.
8e6f2a2 [Michael Armbrust] Review comments.
1e9fb63 [Michael Armbrust] style
bc0cb84 [Michael Armbrust] Rewrite join implementation to allow streaming of one relation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5731af5b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5731af5b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5731af5b

Branch: refs/heads/master
Commit: 5731af5be65ccac831445f351baf040a0d007687
Parents: 841721e
Author: Michael Armbrust <mich...@databricks.com>
Authored: Mon Mar 31 15:23:46 2014 -0700
Committer: Reynold Xin <r...@apache.org>
Committed: Mon Mar 31 15:23:46 2014 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Row.scala    |  10 ++
 .../sql/catalyst/expressions/predicates.scala   |   6 +
 .../scala/org/apache/spark/sql/SQLContext.scala |   2 +-
 .../spark/sql/execution/SparkStrategies.scala   |   6 +-
 .../org/apache/spark/sql/execution/joins.scala  | 127 ++++++++++++++-----
 .../org/apache/spark/sql/hive/HiveContext.scala |   2 +-
 6 files changed, 116 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5731af5b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index 31d42b9..6f939e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -44,6 +44,16 @@ trait Row extends Seq[Any] with Serializable {
     s"[${this.mkString(",")}]"
 
   def copy(): Row
+
+  /** Returns true if there are any NULL values in this row. */
+  def anyNull: Boolean = {
+    var i = 0
+    while (i < length) {
+      if (isNullAt(i)) { return true }
+      i += 1
+    }
+    false
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/5731af5b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 722ff51..02fedd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}
 
+object InterpretedPredicate {
+  def apply(expression: Expression): (Row => Boolean) = {
+    (r: Row) => expression.apply(r).asInstanceOf[Boolean]
+  }
+}
+
 trait Predicate extends Expression {
   self: Product =>
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5731af5b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index cf3c06a..f950ea0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -117,7 +117,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     val strategies: Seq[Strategy] =
       TopK ::
       PartialAggregation ::
-      SparkEquiInnerJoin ::
+      HashJoin ::
       ParquetOperations ::
       BasicOperators ::
       CartesianProduct ::

http://git-wip-us.apache.org/repos/asf/spark/blob/5731af5b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 86f9d3e..e35ac0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>
 
-  object SparkEquiInnerJoin extends Strategy {
+  object HashJoin extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) =>
         logger.debug(s"Considering join: ${predicates ++ condition}")
@@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           val leftKeys = joinKeys.map(_._1)
           val rightKeys = joinKeys.map(_._2)
 
-          val joinOp = execution.SparkEquiInnerJoin(
-            leftKeys, rightKeys, planLater(left), planLater(right))
+          val joinOp = execution.HashJoin(
+            leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
 
           // Make sure other conditions are met if present.
           if (otherPredicates.nonEmpty) {

http://git-wip-us.apache.org/repos/asf/spark/blob/5731af5b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index f0d2114..c89dae9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -17,21 +17,22 @@
 
 package org.apache.spark.sql.execution
 
-import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, BitSet}
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 
-import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
 
-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+sealed abstract class BuildSide
+case object BuildLeft extends BuildSide
+case object BuildRight extends BuildSide
 
-case class SparkEquiInnerJoin(
+case class HashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    buildSide: BuildSide,
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
@@ -40,33 +41,93 @@ case class SparkEquiInnerJoin(
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  val (buildPlan, streamedPlan) = buildSide match {
+    case BuildLeft => (left, right)
+    case BuildRight => (right, left)
+  }
+
+  val (buildKeys, streamedKeys) = buildSide match {
+    case BuildLeft => (leftKeys, rightKeys)
+    case BuildRight => (rightKeys, leftKeys)
+  }
+
   def output = left.output ++ right.output
 
-  def execute() = attachTree(this, "execute") {
-    val leftWithKeys = left.execute().mapPartitions { iter =>
-      val generateLeftKeys = new Projection(leftKeys, left.output)
-      iter.map(row => (generateLeftKeys(row), row.copy()))
-    }
+  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+  @transient lazy val streamSideKeyGenerator =
+    () => new MutableProjection(streamedKeys, streamedPlan.output)
 
-    val rightWithKeys = right.execute().mapPartitions { iter =>
-      val generateRightKeys = new Projection(rightKeys, right.output)
-      iter.map(row => (generateRightKeys(row), row.copy()))
-    }
+  def execute() = {
 
-    // Do the join.
-    val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
-    // Drop join keys and merge input tuples.
-    joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
-  }
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      // TODO: Use Spark's HashMap implementation.
+      val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+      var currentRow: Row = null
+
+      // Create a mapping of buildKeys -> rows
+      while (buildIter.hasNext) {
+        currentRow = buildIter.next()
+        val rowKey = buildSideKeyGenerator(currentRow)
+        if(!rowKey.anyNull) {
+          val existingMatchList = hashTable.get(rowKey)
+          val matchList = if (existingMatchList == null) {
+            val newMatchList = new ArrayBuffer[Row]()
+            hashTable.put(rowKey, newMatchList)
+            newMatchList
+          } else {
+            existingMatchList
+          }
+          matchList += currentRow.copy()
+        }
+      }
+
+      new Iterator[Row] {
+        private[this] var currentStreamedRow: Row = _
+        private[this] var currentHashMatches: ArrayBuffer[Row] = _
+        private[this] var currentMatchPosition: Int = -1
 
-  /**
-   * Filters any rows where the any of the join keys is null, ensuring three-valued
-   * logic for the equi-join conditions.
-   */
-  protected def filterNulls(rdd: RDD[(Row, Row)]) =
-    rdd.filter {
-      case (key: Seq[_], _) => !key.exists(_ == null)
+        // Mutable per row objects.
+        private[this] val joinRow = new JoinedRow
+
+        private[this] val joinKeys = streamSideKeyGenerator()
+
+        override final def hasNext: Boolean =
+          (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
+          (streamIter.hasNext && fetchNext())
+
+        override final def next() = {
+          val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+          currentMatchPosition += 1
+          ret
+        }
+
+        /**
+         * Searches the streamed iterator for the next row that has at least one match in hashtable.
+         *
+         * @return true if the search is successful, and false if the streamed iterator runs out of
+         *         tuples.
+         */
+        private final def fetchNext(): Boolean = {
+          currentHashMatches = null
+          currentMatchPosition = -1
+
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            if (!joinKeys(currentStreamedRow).anyNull) {
+              currentHashMatches = hashTable.get(joinKeys.currentValue)
+            }
+          }
+
+          if (currentHashMatches == null) {
+            false
+          } else {
+            currentMatchPosition = 0
+            true
+          }
+        }
+      }
     }
+  }
 }
 
 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
@@ -95,17 +156,19 @@ case class BroadcastNestedLoopJoin(
   def right = broadcast
 
   @transient lazy val boundCondition =
-    condition
-      .map(c => BindReferences.bindReference(c, left.output ++ right.output))
-      .getOrElse(Literal(true))
+    InterpretedPredicate(
+      condition
+        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+        .getOrElse(Literal(true)))
 
 
   def execute() = {
     val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
-      val matchedRows = new mutable.ArrayBuffer[Row]
-      val includedBroadcastTuples =  new mutable.BitSet(broadcastedRelation.value.size)
+      val matchedRows = new ArrayBuffer[Row]
+      // TODO: Use Spark's BitSet.
+      val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow
 
       streamedIter.foreach { streamedRow =>
@@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin(
         while (i < broadcastedRelation.value.size) {
           // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
-          if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) {
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
             matchedRows += buildRow(streamedRow ++ broadcastedRow)
             matched = true
             includedBroadcastTuples += i

http://git-wip-us.apache.org/repos/asf/spark/blob/5731af5b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index fc5057b..197b557 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       DataSinks,
       Scripts,
       PartialAggregation,
-      SparkEquiInnerJoin,
+      HashJoin,
       BasicOperators,
       CartesianProduct,
       BroadcastNestedLoopJoin
