Repository: spark
Updated Branches:
  refs/heads/branch-1.5 01efa4f27 -> f9beef998


[SPARK-9729] [SPARK-9363] [SQL] Use sort merge join for left and right outer 
join

This patch adds a new `SortMergeOuterJoin` operator that performs left and 
right outer joins using sort merge join.  It also refactors `SortMergeJoin` in 
order to improve performance and code clarity.

Along the way, I also performed a couple pieces of minor cleanup and 
optimization:

- Rename the `HashJoin` physical planner rule to `EquiJoinSelection`, since 
it's also used for non-hash joins.
- Rewrite the comment at the top of `HashJoin` to better explain the precedence 
for choosing join operators.
- Update `JoinSuite` to use `SQLTestUtils.withConf` for changing SQLConf 
settings.

This patch incorporates several ideas from adrian-wang's patch, #5717.

Closes #5717.

<!-- Reviewable:start -->
[<img src="https://reviewable.io/review_button.png" height="40" alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/7904)
<!-- Reviewable:end -->

Author: Josh Rosen <joshro...@databricks.com>
Author: Daoyuan Wang <daoyuan.w...@intel.com>

Closes #7904 from JoshRosen/outer-join-smj and squashes 1 commits.

(cherry picked from commit 91e9389f39509e63654bd4bcb7bd919eaedda910)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f9beef99
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f9beef99
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f9beef99

Branch: refs/heads/branch-1.5
Commit: f9beef9987c6a1993990e6695fb295019e5ed5d3
Parents: 01efa4f
Author: Josh Rosen <joshro...@databricks.com>
Authored: Mon Aug 10 22:04:41 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Aug 10 22:04:50 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/JoinedRow.scala    |   6 +-
 .../scala/org/apache/spark/sql/SQLContext.scala |   2 +-
 .../spark/sql/execution/RowIterator.scala       |  93 ++++++
 .../spark/sql/execution/SparkStrategies.scala   |  45 ++-
 .../joins/BroadcastNestedLoopJoin.scala         |   5 +-
 .../sql/execution/joins/SortMergeJoin.scala     | 331 +++++++++++++------
 .../execution/joins/SortMergeOuterJoin.scala    | 251 ++++++++++++++
 .../scala/org/apache/spark/sql/JoinSuite.scala  | 132 ++++----
 .../sql/execution/joins/InnerJoinSuite.scala    | 180 ++++++++++
 .../sql/execution/joins/OuterJoinSuite.scala    | 310 ++++++++++++-----
 .../sql/execution/joins/SemiJoinSuite.scala     | 125 ++++---
 .../apache/spark/sql/test/SQLTestUtils.scala    |   2 +-
 .../org/apache/spark/sql/hive/HiveContext.scala |   2 +-
 13 files changed, 1165 insertions(+), 319 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
index b76757c..d3560df 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
@@ -37,20 +37,20 @@ class JoinedRow extends InternalRow {
   }
 
   /** Updates this JoinedRow to used point at two new base rows.  Returns 
itself. */
-  def apply(r1: InternalRow, r2: InternalRow): InternalRow = {
+  def apply(r1: InternalRow, r2: InternalRow): JoinedRow = {
     row1 = r1
     row2 = r2
     this
   }
 
   /** Updates this JoinedRow by updating its left base row.  Returns itself. */
-  def withLeft(newLeft: InternalRow): InternalRow = {
+  def withLeft(newLeft: InternalRow): JoinedRow = {
     row1 = newLeft
     this
   }
 
   /** Updates this JoinedRow by updating its right base row.  Returns itself. 
*/
-  def withRight(newRight: InternalRow): InternalRow = {
+  def withRight(newRight: InternalRow): JoinedRow = {
     row2 = newRight
     this
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index f73bb04..4bf00b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -873,7 +873,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       HashAggregation ::
       Aggregation ::
       LeftSemiJoin ::
-      HashJoin ::
+      EquiJoinSelection ::
       InMemoryScans ::
       BasicOperators ::
       CartesianProduct ::

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
new file mode 100644
index 0000000..7462dbc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.NoSuchElementException
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * An internal iterator interface which presents a more restrictive API than
+ * [[scala.collection.Iterator]].
+ *
+ * One major departure from the Scala iterator API is the fusing of the 
`hasNext()` and `next()`
+ * calls: Scala's iterator allows users to call `hasNext()` without 
immediately advancing the
+ * iterator to consume the next row, whereas RowIterator combines these calls 
into a single
+ * [[advanceNext()]] method.
+ */
+private[sql] abstract class RowIterator {
+  /**
+   * Advance this iterator by a single row. Returns `false` if this iterator 
has no more rows
+   * and `true` otherwise. If this returns `true`, then the new row can be 
retrieved by calling
+   * [[getRow]].
+   */
+  def advanceNext(): Boolean
+
+  /**
+   * Retrieve the row from this iterator. This method is idempotent. It is 
illegal to call this
+   * method after [[advanceNext()]] has returned `false`.
+   */
+  def getRow: InternalRow
+
+  /**
+   * Convert this RowIterator into a [[scala.collection.Iterator]].
+   */
+  def toScala: Iterator[InternalRow] = new RowIteratorToScala(this)
+}
+
+object RowIterator {
+  def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = {
+    scalaIter match {
+      case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter
+      case _ => new RowIteratorFromScala(scalaIter)
+    }
+  }
+}
+
+private final class RowIteratorToScala(val rowIter: RowIterator) extends 
Iterator[InternalRow] {
+  private [this] var hasNextWasCalled: Boolean = false
+  private [this] var _hasNext: Boolean = false
+  override def hasNext: Boolean = {
+    // Idempotency:
+    if (!hasNextWasCalled) {
+      _hasNext = rowIter.advanceNext()
+      hasNextWasCalled = true
+    }
+    _hasNext
+  }
+  override def next(): InternalRow = {
+    if (!hasNext) throw new NoSuchElementException
+    hasNextWasCalled = false
+    rowIter.getRow
+  }
+}
+
+private final class RowIteratorFromScala(scalaIter: Iterator[InternalRow]) 
extends RowIterator {
+  private[this] var _next: InternalRow = null
+  override def advanceNext(): Boolean = {
+    if (scalaIter.hasNext) {
+      _next = scalaIter.next()
+      true
+    } else {
+      _next = null
+      false
+    }
+  }
+  override def getRow: InternalRow = _next
+  override def toScala: Iterator[InternalRow] = scalaIter
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c4b9b5a..1fc870d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -63,19 +63,23 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
   }
 
   /**
-   * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of 
the predicates can be
-   * evaluated by matching hash keys.
+   * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least 
some of the predicates
+   * can be evaluated by matching join keys.
    *
-   * This strategy applies a simple optimization based on the estimates of the 
physical sizes of
-   * the two join sides.  When planning a [[joins.BroadcastHashJoin]], if one 
side has an
-   * estimated physical size smaller than the user-settable threshold
-   * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the 
planner would mark it as the
-   * ''build'' relation and mark the other relation as the ''stream'' side.  
The build table will be
-   * ''broadcasted'' to all of the executors involved in the join, as a
-   * [[org.apache.spark.broadcast.Broadcast]] object.  If both estimates 
exceed the threshold, they
-   * will instead be used to decide the build side in a 
[[joins.ShuffledHashJoin]].
+   * Join implementations are chosen with the following precedence:
+   *
+   * - Broadcast: if one side of the join has an estimated physical size that 
is smaller than the
+   *     user-configurable 
[[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold
+   *     or if that side has an explicit broadcast hint (e.g. the user applied 
the
+   *     [[org.apache.spark.sql.functions.broadcast()]] function to a 
DataFrame), then that side
+   *     of the join will be broadcasted and the other side will be streamed, 
with no shuffling
+   *     performed. If both sides of an inner join are eligible to be broadcasted,
+   *     then the right side is broadcasted.
+   * - Sort merge: if the matching join keys are sortable and
+   *     [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), 
then sort merge join
+   *     will be used.
+   * - Hash: will be chosen if neither of the above optimizations apply to 
this join.
    */
-  object HashJoin extends Strategy with PredicateHelper {
+  object EquiJoinSelection extends Strategy with PredicateHelper {
 
     private[this] def makeBroadcastHashJoin(
         leftKeys: Seq[Expression],
@@ -90,14 +94,15 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
     }
 
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+
+      // --- Inner joins 
--------------------------------------------------------------------------
+
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, 
CanBroadcast(right)) =>
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, 
joins.BuildRight)
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, 
CanBroadcast(left), right) =>
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, 
joins.BuildLeft)
 
-      // If the sort merge join option is set, we want to use sort merge join 
prior to hashjoin
-      // for now let's support inner join first, then add outer join
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, 
right)
         if sqlContext.conf.sortMergeJoinEnabled && 
RowOrdering.isOrderable(leftKeys) =>
         val mergeJoin =
@@ -115,6 +120,8 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
           leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
         condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
 
+      // --- Outer joins 
--------------------------------------------------------------------------
+
       case ExtractEquiJoinKeys(
              LeftOuter, leftKeys, rightKeys, condition, left, 
CanBroadcast(right)) =>
         joins.BroadcastHashOuterJoin(
@@ -125,10 +132,22 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         joins.BroadcastHashOuterJoin(
           leftKeys, rightKeys, RightOuter, condition, planLater(left), 
planLater(right)) :: Nil
 
+      case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, 
left, right)
+        if sqlContext.conf.sortMergeJoinEnabled && 
RowOrdering.isOrderable(leftKeys) =>
+        joins.SortMergeOuterJoin(
+          leftKeys, rightKeys, LeftOuter, condition, planLater(left), 
planLater(right)) :: Nil
+
+      case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, 
left, right)
+        if sqlContext.conf.sortMergeJoinEnabled && 
RowOrdering.isOrderable(leftKeys) =>
+        joins.SortMergeOuterJoin(
+          leftKeys, rightKeys, RightOuter, condition, planLater(left), 
planLater(right)) :: Nil
+
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right) =>
         joins.ShuffledHashOuterJoin(
           leftKeys, rightKeys, joinType, condition, planLater(left), 
planLater(right)) :: Nil
 
+      // --- Cases where this strategy does not apply 
---------------------------------------------
+
       case _ => Nil
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 23aebf4..017a44b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -65,8 +65,9 @@ case class BroadcastNestedLoopJoin(
         left.output.map(_.withNullability(true)) ++ right.output
       case FullOuter =>
         left.output.map(_.withNullability(true)) ++ 
right.output.map(_.withNullability(true))
-      case _ =>
-        left.output ++ right.output
+      case x =>
+        throw new IllegalArgumentException(
+          s"BroadcastNestedLoopJoin should not take $x as the JoinType")
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 4ae23c1..6d656ea 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.sql.execution.joins
 
-import java.util.NoSuchElementException
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
 
 /**
  * :: DeveloperApi ::
@@ -38,8 +37,6 @@ case class SortMergeJoin(
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
-  override protected[sql] val trackNumOfRowsEnabled = true
-
   override def output: Seq[Attribute] = left.output ++ right.output
 
   override def outputPartitioning: Partitioning =
@@ -56,117 +53,265 @@ case class SortMergeJoin(
   @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, 
left.output)
   @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, 
right.output)
 
+  protected[this] def isUnsafeMode: Boolean = {
+    (codegenEnabled && unsafeEnabled
+      && UnsafeProjection.canSupport(leftKeys)
+      && UnsafeProjection.canSupport(rightKeys)
+      && UnsafeProjection.canSupport(schema))
+  }
+
+  override def outputsUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessSafeRows: Boolean = !isUnsafeMode
+
   private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
     // This must be ascending in order to agree with the `keyOrdering` defined 
in `doExecute()`.
     keys.map(SortOrder(_, Ascending))
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
-    val leftResults = left.execute().map(_.copy())
-    val rightResults = right.execute().map(_.copy())
-
-    leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
-      new Iterator[InternalRow] {
+    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      new RowIterator {
         // An ordering that can be used to compare keys from both sides.
         private[this] val keyOrdering = 
newNaturalAscendingOrdering(leftKeys.map(_.dataType))
-        // Mutable per row objects.
+        private[this] var currentLeftRow: InternalRow = _
+        private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
+        private[this] var currentMatchIdx: Int = -1
+        private[this] val smjScanner = new SortMergeJoinScanner(
+          leftKeyGenerator,
+          rightKeyGenerator,
+          keyOrdering,
+          RowIterator.fromScala(leftIter),
+          RowIterator.fromScala(rightIter)
+        )
         private[this] val joinRow = new JoinedRow
-        private[this] var leftElement: InternalRow = _
-        private[this] var rightElement: InternalRow = _
-        private[this] var leftKey: InternalRow = _
-        private[this] var rightKey: InternalRow = _
-        private[this] var rightMatches: CompactBuffer[InternalRow] = _
-        private[this] var rightPosition: Int = -1
-        private[this] var stop: Boolean = false
-        private[this] var matchKey: InternalRow = _
-
-        // initialize iterator
-        initialize()
-
-        override final def hasNext: Boolean = nextMatchingPair()
-
-        override final def next(): InternalRow = {
-          if (hasNext) {
-            // we are using the buffered right rows and run down left iterator
-            val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
-            rightPosition += 1
-            if (rightPosition >= rightMatches.size) {
-              rightPosition = 0
-              fetchLeft()
-              if (leftElement == null || keyOrdering.compare(leftKey, 
matchKey) != 0) {
-                stop = false
-                rightMatches = null
-              }
-            }
-            joinedRow
+        private[this] val resultProjection: (InternalRow) => InternalRow = {
+          if (isUnsafeMode) {
+            UnsafeProjection.create(schema)
           } else {
-            // no more result
-            throw new NoSuchElementException
+            identity[InternalRow]
           }
         }
 
-        private def fetchLeft() = {
-          if (leftIter.hasNext) {
-            leftElement = leftIter.next()
-            leftKey = leftKeyGenerator(leftElement)
-          } else {
-            leftElement = null
+        override def advanceNext(): Boolean = {
+          if (currentMatchIdx == -1 || currentMatchIdx == 
currentRightMatches.length) {
+            if (smjScanner.findNextInnerJoinRows()) {
+              currentRightMatches = smjScanner.getBufferedMatches
+              currentLeftRow = smjScanner.getStreamedRow
+              currentMatchIdx = 0
+            } else {
+              currentRightMatches = null
+              currentLeftRow = null
+              currentMatchIdx = -1
+            }
           }
-        }
-
-        private def fetchRight() = {
-          if (rightIter.hasNext) {
-            rightElement = rightIter.next()
-            rightKey = rightKeyGenerator(rightElement)
+          if (currentLeftRow != null) {
+            joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
+            currentMatchIdx += 1
+            true
           } else {
-            rightElement = null
+            false
           }
         }
 
-        private def initialize() = {
-          fetchLeft()
-          fetchRight()
+        override def getRow: InternalRow = resultProjection(joinRow)
+      }.toScala
+    }
+  }
+}
+
+/**
+ * Helper class that is used to implement [[SortMergeJoin]] and 
[[SortMergeOuterJoin]].
+ *
+ * To perform an inner (outer) join, users of this class call 
[[findNextInnerJoinRows()]]
+ * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been 
produced and `false`
+ * otherwise. If a result has been produced, then the caller may call 
[[getStreamedRow]] to return
+ * the matching row from the streamed input and may call 
[[getBufferedMatches]] to return the
+ * sequence of matching rows from the buffered input (in the case of an outer 
join, this will return
+ * an empty sequence if there are no matches from the buffered input). For 
efficiency, both of these
+ * methods return mutable objects which are re-used across calls to the 
`findNext*JoinRows()`
+ * methods.
+ *
+ * @param streamedKeyGenerator a projection that produces join keys from the 
streamed input.
+ * @param bufferedKeyGenerator a projection that produces join keys from the 
buffered input.
+ * @param keyOrdering an ordering which can be used to compare join keys.
+ * @param streamedIter an input whose rows will be streamed.
+ * @param bufferedIter an input whose rows will be buffered to construct 
sequences of rows that
+ *                     have the same join key.
+ */
+private[joins] class SortMergeJoinScanner(
+    streamedKeyGenerator: Projection,
+    bufferedKeyGenerator: Projection,
+    keyOrdering: Ordering[InternalRow],
+    streamedIter: RowIterator,
+    bufferedIter: RowIterator) {
+  private[this] var streamedRow: InternalRow = _
+  private[this] var streamedRowKey: InternalRow = _
+  private[this] var bufferedRow: InternalRow = _
+  // Note: this is guaranteed to never have any null columns:
+  private[this] var bufferedRowKey: InternalRow = _
+  /**
+   * The join key for the rows buffered in `bufferedMatches`, or null if 
`bufferedMatches` is empty
+   */
+  private[this] var matchJoinKey: InternalRow = _
+  /** Buffered rows from the buffered side of the join. This is empty if there 
are no matches. */
+  private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new 
ArrayBuffer[InternalRow]
+
+  // Initialization (note: we do _not_ want to advance the streamed iterator here).
+  advancedBufferedToRowWithNullFreeJoinKey()
+
+  // --- Public methods 
---------------------------------------------------------------------------
+
+  def getStreamedRow: InternalRow = streamedRow
+
+  def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+
+  /**
+   * Advances both input iterators, stopping when we have found rows with 
matching join keys.
+   * @return true if matching rows have been found and false otherwise. If 
this returns true, then
+   *         [[getStreamedRow]] and [[getBufferedMatches]] can be called to 
construct the join
+   *         results.
+   */
+  final def findNextInnerJoinRows(): Boolean = {
+    while (advancedStreamed() && streamedRowKey.anyNull) {
+      // Advance the streamed side of the join until we find the next row 
whose join key contains
+      // no nulls or we hit the end of the streamed iterator.
+    }
+    if (streamedRow == null) {
+      // We have consumed the entire streamed iterator, so there can be no 
more matches.
+      matchJoinKey = null
+      bufferedMatches.clear()
+      false
+    } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, 
matchJoinKey) == 0) {
+      // The new streamed row has the same join key as the previous row, so 
return the same matches.
+      true
+    } else if (bufferedRow == null) {
+      // The streamed row's join key does not match the current batch of 
buffered rows and there are
+      // no more rows to read from the buffered iterator, so there can be no 
more matches.
+      matchJoinKey = null
+      bufferedMatches.clear()
+      false
+    } else {
+      // Advance both the streamed and buffered iterators to find the next 
pair of matching rows.
+      var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+      do {
+        if (streamedRowKey.anyNull) {
+          advancedStreamed()
+        } else {
+          assert(!bufferedRowKey.anyNull)
+          comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+          if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
+          else if (comp < 0) advancedStreamed()
         }
+      } while (streamedRow != null && bufferedRow != null && comp != 0)
+      if (streamedRow == null || bufferedRow == null) {
+        // We have hit the end of one of the iterators, so there can be no more matches.
+        matchJoinKey = null
+        bufferedMatches.clear()
+        false
+      } else {
+        // The streamed row's join key matches the current buffered row's join key, so walk
+        // through the buffered iterator to buffer the rest of the matching rows.
+        assert(comp == 0)
+        bufferMatchingRows()
+        true
+      }
+    }
+  }
 
-        /**
-         * Searches the right iterator for the next rows that have matches in 
left side, and store
-         * them in a buffer.
-         *
-         * @return true if the search is successful, and false if the right 
iterator runs out of
-         *         tuples.
-         */
-        private def nextMatchingPair(): Boolean = {
-          if (!stop && rightElement != null) {
-            // run both side to get the first match pair
-            while (!stop && leftElement != null && rightElement != null) {
-              val comparing = keyOrdering.compare(leftKey, rightKey)
-              // for inner join, we need to filter those null keys
-              stop = comparing == 0 && !leftKey.anyNull
-              if (comparing > 0 || rightKey.anyNull) {
-                fetchRight()
-              } else if (comparing < 0 || leftKey.anyNull) {
-                fetchLeft()
-              }
-            }
-            rightMatches = new CompactBuffer[InternalRow]()
-            if (stop) {
-              stop = false
-              // iterate the right side to buffer all rows that matches
-              // as the records should be ordered, exit when we meet the first 
that not match
-              while (!stop && rightElement != null) {
-                rightMatches += rightElement
-                fetchRight()
-                stop = keyOrdering.compare(leftKey, rightKey) != 0
-              }
-              if (rightMatches.size > 0) {
-                rightPosition = 0
-                matchKey = leftKey
-              }
-            }
+  /**
+   * Advances the streamed input iterator and buffers all rows from the 
buffered input that
+   * have matching keys.
+   * @return true if the streamed iterator returned a row, false otherwise. If 
this returns true,
+   *         then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer
+   *         join results.
+   */
+  final def findNextOuterJoinRows(): Boolean = {
+    if (!advancedStreamed()) {
+      // We have consumed the entire streamed iterator, so there can be no 
more matches.
+      matchJoinKey = null
+      bufferedMatches.clear()
+      false
+    } else {
+      if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, 
matchJoinKey) == 0) {
+        // Matches the current group, so do nothing.
+      } else {
+        // The streamed row does not match the current group.
+        matchJoinKey = null
+        bufferedMatches.clear()
+        if (bufferedRow != null && !streamedRowKey.anyNull) {
+          // The buffered iterator could still contain matching rows, so we'll 
need to walk through
+          // it until we either find matches or pass where they would be found.
+          var comp = 1
+          do {
+            comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+          } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey())
+          if (comp == 0) {
+            // We have found matches, so buffer them (this updates 
matchJoinKey)
+            bufferMatchingRows()
+          } else {
+            // We have overshot the position where the row would be found, 
hence no matches.
           }
-          rightMatches != null && rightMatches.size > 0
         }
       }
+      // If there is a streamed input then we always return true
+      true
     }
   }
+
+  // --- Private methods 
--------------------------------------------------------------------------
+
+  /**
+   * Advance the streamed iterator and compute the new row's join key.
+   * @return true if the streamed iterator returned a row and false otherwise.
+   */
+  private def advancedStreamed(): Boolean = {
+    if (streamedIter.advanceNext()) {
+      streamedRow = streamedIter.getRow
+      streamedRowKey = streamedKeyGenerator(streamedRow)
+      true
+    } else {
+      streamedRow = null
+      streamedRowKey = null
+      false
+    }
+  }
+
+  /**
+   * Advance the buffered iterator until we find a row whose join key does not contain nulls.
+   * @return true if the buffered iterator returned a row and false otherwise.
+   */
+  private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
+    var foundRow: Boolean = false
+    while (!foundRow && bufferedIter.advanceNext()) {
+      bufferedRow = bufferedIter.getRow
+      bufferedRowKey = bufferedKeyGenerator(bufferedRow)
+      foundRow = !bufferedRowKey.anyNull
+    }
+    if (!foundRow) {
+      bufferedRow = null
+      bufferedRowKey = null
+      false
+    } else {
+      true
+    }
+  }
+
+  /**
+   * Called when the streamed and buffered join keys match in order to buffer 
the matching rows.
+   */
+  private def bufferMatchingRows(): Unit = {
+    assert(streamedRowKey != null)
+    assert(!streamedRowKey.anyNull)
+    assert(bufferedRowKey != null)
+    assert(!bufferedRowKey.anyNull)
+    assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
+    // This join key may have been produced by a mutable projection, so we 
need to make a copy:
+    matchJoinKey = streamedRowKey.copy()
+    bufferedMatches.clear()
+    do {
+      bufferedMatches += bufferedRow.copy() // need to copy mutable rows 
before buffering them
+      advancedBufferedToRowWithNullFreeJoinKey()
+    } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, 
bufferedRowKey) == 0)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
new file mode 100644
index 0000000..5326966
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Performs an sort merge outer join of two child relations.
+ *
+ * Note: this does not support full outer join yet; see SPARK-9730 for progress on this.
+ */
+@DeveloperApi
+case class SortMergeOuterJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
+
+  override def output: Seq[Attribute] = {
+    joinType match {
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case x =>
+        throw new IllegalArgumentException(
+          s"${getClass.getSimpleName} should not take $x as the JoinType")
+    }
+  }
+
+  override def outputPartitioning: Partitioning = joinType match {
+    // For left and right outer joins, the output is partitioned by the streamed input's join keys.
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case x =>
+      throw new IllegalArgumentException(
+        s"${getClass.getSimpleName} should not take $x as the JoinType")
+  }
+
+  override def outputOrdering: Seq[SortOrder] = joinType match {
+    // For left and right outer joins, the output is ordered by the streamed input's join keys.
+    case LeftOuter => requiredOrders(leftKeys)
+    case RightOuter => requiredOrders(rightKeys)
+    case x => throw new IllegalArgumentException(
+      s"SortMergeOuterJoin should not take $x as the JoinType")
+  }
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
+
+  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
+    // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
+    keys.map(SortOrder(_, Ascending))
+  }
+
+  private def isUnsafeMode: Boolean = {
+    (codegenEnabled && unsafeEnabled
+      && UnsafeProjection.canSupport(leftKeys)
+      && UnsafeProjection.canSupport(rightKeys)
+      && UnsafeProjection.canSupport(schema))
+  }
+
+  override def outputsUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessUnsafeRows: Boolean = isUnsafeMode
+  override def canProcessSafeRows: Boolean = !isUnsafeMode
+
+  private def createLeftKeyGenerator(): Projection = {
+    if (isUnsafeMode) {
+      UnsafeProjection.create(leftKeys, left.output)
+    } else {
+      newProjection(leftKeys, left.output)
+    }
+  }
+
+  private def createRightKeyGenerator(): Projection = {
+    if (isUnsafeMode) {
+      UnsafeProjection.create(rightKeys, right.output)
+    } else {
+      newProjection(rightKeys, right.output)
+    }
+  }
+
+  override def doExecute(): RDD[InternalRow] = {
+    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      // An ordering that can be used to compare keys from both sides.
+      val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+      val boundCondition: (InternalRow) => Boolean = {
+        condition.map { cond =>
+          newPredicate(cond, left.output ++ right.output)
+        }.getOrElse {
+          (r: InternalRow) => true
+        }
+      }
+      val resultProj: InternalRow => InternalRow = {
+        if (isUnsafeMode) {
+          UnsafeProjection.create(schema)
+        } else {
+          identity[InternalRow]
+        }
+      }
+
+      joinType match {
+        case LeftOuter =>
+          val smjScanner = new SortMergeJoinScanner(
+            streamedKeyGenerator = createLeftKeyGenerator(),
+            bufferedKeyGenerator = createRightKeyGenerator(),
+            keyOrdering,
+            streamedIter = RowIterator.fromScala(leftIter),
+            bufferedIter = RowIterator.fromScala(rightIter)
+          )
+          val rightNullRow = new GenericInternalRow(right.output.length)
+          new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala
+
+        case RightOuter =>
+          val smjScanner = new SortMergeJoinScanner(
+            streamedKeyGenerator = createRightKeyGenerator(),
+            bufferedKeyGenerator = createLeftKeyGenerator(),
+            keyOrdering,
+            streamedIter = RowIterator.fromScala(rightIter),
+            bufferedIter = RowIterator.fromScala(leftIter)
+          )
+          val leftNullRow = new GenericInternalRow(left.output.length)
+          new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala
+
+        case x =>
+          throw new IllegalArgumentException(
+            s"SortMergeOuterJoin should not take $x as the JoinType")
+      }
+    }
+  }
+}
+
+
+private class LeftOuterIterator(
+    smjScanner: SortMergeJoinScanner,
+    rightNullRow: InternalRow,
+    boundCondition: InternalRow => Boolean,
+    resultProj: InternalRow => InternalRow
+  ) extends RowIterator {
+  private[this] val joinedRow: JoinedRow = new JoinedRow()
+  private[this] var rightIdx: Int = 0
+  assert(smjScanner.getBufferedMatches.length == 0)
+
+  private def advanceLeft(): Boolean = {
+    rightIdx = 0
+    if (smjScanner.findNextOuterJoinRows()) {
+      joinedRow.withLeft(smjScanner.getStreamedRow)
+      if (smjScanner.getBufferedMatches.isEmpty) {
+        // There are no matching right rows, so return nulls for the right row
+        joinedRow.withRight(rightNullRow)
+      } else {
+        // Find the next row from the right input that satisfied the bound condition
+        if (!advanceRightUntilBoundConditionSatisfied()) {
+          joinedRow.withRight(rightNullRow)
+        }
+      }
+      true
+    } else {
+      // Left input has been exhausted
+      false
+    }
+  }
+
+  private def advanceRightUntilBoundConditionSatisfied(): Boolean = {
+    var foundMatch: Boolean = false
+    while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) {
+      foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx)))
+      rightIdx += 1
+    }
+    foundMatch
+  }
+
+  override def advanceNext(): Boolean = {
+    advanceRightUntilBoundConditionSatisfied() || advanceLeft()
+  }
+
+  override def getRow: InternalRow = resultProj(joinedRow)
+}
+
+private class RightOuterIterator(
+    smjScanner: SortMergeJoinScanner,
+    leftNullRow: InternalRow,
+    boundCondition: InternalRow => Boolean,
+    resultProj: InternalRow => InternalRow
+  ) extends RowIterator {
+  private[this] val joinedRow: JoinedRow = new JoinedRow()
+  private[this] var leftIdx: Int = 0
+  assert(smjScanner.getBufferedMatches.length == 0)
+
+  private def advanceRight(): Boolean = {
+    leftIdx = 0
+    if (smjScanner.findNextOuterJoinRows()) {
+      joinedRow.withRight(smjScanner.getStreamedRow)
+      if (smjScanner.getBufferedMatches.isEmpty) {
+        // There are no matching left rows, so return nulls for the left row
+        joinedRow.withLeft(leftNullRow)
+      } else {
+        // Find the next row from the left input that satisfied the bound condition
+        if (!advanceLeftUntilBoundConditionSatisfied()) {
+          joinedRow.withLeft(leftNullRow)
+        }
+      }
+      true
+    } else {
+      // Right input has been exhausted
+      false
+    }
+  }
+
+  private def advanceLeftUntilBoundConditionSatisfied(): Boolean = {
+    var foundMatch: Boolean = false
+    while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) {
+      foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx)))
+      leftIdx += 1
+    }
+    foundMatch
+  }
+
+  override def advanceNext(): Boolean = {
+    advanceLeftUntilBoundConditionSatisfied() || advanceRight()
+  }
+
+  override def getRow: InternalRow = resultProj(joinedRow)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 5bef1d8..ae07eaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -22,13 +22,14 @@ import org.scalatest.BeforeAndAfterEach
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.types.BinaryType
+import org.apache.spark.sql.test.SQLTestUtils
 
 
-class JoinSuite extends QueryTest with BeforeAndAfterEach {
+class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
   // Ensures tables are loaded.
   TestData
 
+  override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
   lazy val ctx = org.apache.spark.sql.test.TestSQLContext
   import ctx.implicits._
   import ctx.logicalPlanToSparkQuery
@@ -37,7 +38,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
     val x = testData2.as("x")
     val y = testData2.as("y")
     val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
-    val planned = ctx.planner.HashJoin(join)
+    val planned = ctx.planner.EquiJoinSelection(join)
     assert(planned.size === 1)
   }
 
@@ -55,6 +56,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       case j: BroadcastNestedLoopJoin => j
       case j: BroadcastLeftSemiJoinHash => j
       case j: SortMergeJoin => j
+      case j: SortMergeOuterJoin => j
     }
 
     assert(operators.size === 1)
@@ -66,7 +68,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
   test("join operator selection") {
     ctx.cacheManager.clearCache()
 
-    val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
     Seq(
       ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
       ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
@@ -83,11 +84,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
       ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
       ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
+      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[ShuffledHashOuterJoin]),
+        classOf[SortMergeOuterJoin]),
       ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[ShuffledHashOuterJoin]),
+        classOf[SortMergeOuterJoin]),
       ("SELECT * FROM testData full outer join testData2 ON key = a",
         classOf[ShuffledHashOuterJoin]),
       ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
@@ -97,82 +98,75 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
         classOf[BroadcastNestedLoopJoin])
     ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
-    try {
-      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
+    withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
       Seq(
-        ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
-        ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
-        ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
+        ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2",
+          classOf[ShuffledHashJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2",
+          classOf[ShuffledHashJoin]),
+        ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
+        ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
+          classOf[ShuffledHashOuterJoin]),
+        ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
+          classOf[ShuffledHashOuterJoin]),
+        ("SELECT * FROM testData full outer join testData2 ON key = a",
+          classOf[ShuffledHashOuterJoin])
       ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
-    } finally {
-      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
     }
   }
 
   test("SortMergeJoin shouldn't work on unsortable columns") {
-    val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
-    try {
-      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
+    withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
       Seq(
         ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin])
       ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
-    } finally {
-      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
     }
   }
 
   test("broadcasted hash join operator selection") {
     ctx.cacheManager.clearCache()
     ctx.sql("CACHE TABLE testData")
-
-    val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
-    Seq(
-      ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a where key = 2",
-        classOf[BroadcastHashJoin])
-    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
-    try {
-      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
-      Seq(
-        ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
-        ("SELECT * FROM testData join testData2 ON key = a and key = 2",
-          classOf[BroadcastHashJoin]),
-        ("SELECT * FROM testData join testData2 ON key = a where key = 2",
-          classOf[BroadcastHashJoin])
-      ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
-    } finally {
-      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
+    for (sortMergeJoinEnabled <- Seq(true, false)) {
+      withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") {
+        withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") {
+          Seq(
+            ("SELECT * FROM testData join testData2 ON key = a",
+              classOf[BroadcastHashJoin]),
+            ("SELECT * FROM testData join testData2 ON key = a and key = 2",
+              classOf[BroadcastHashJoin]),
+            ("SELECT * FROM testData join testData2 ON key = a where key = 2",
+              classOf[BroadcastHashJoin])
+          ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+        }
+      }
     }
-
     ctx.sql("UNCACHE TABLE testData")
   }
 
   test("broadcasted hash outer join operator selection") {
     ctx.cacheManager.clearCache()
     ctx.sql("CACHE TABLE testData")
-
-    val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
-    Seq(
-      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
-      ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[BroadcastHashOuterJoin]),
-      ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[BroadcastHashOuterJoin])
-    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
-    try {
-      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
+    withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
       Seq(
-        ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
+        ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
+          classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
+          classOf[BroadcastHashOuterJoin]),
+        ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
+          classOf[BroadcastHashOuterJoin])
+      ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    }
+    withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
+      Seq(
+        ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
+          classOf[ShuffledHashOuterJoin]),
         ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
           classOf[BroadcastHashOuterJoin]),
         ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
           classOf[BroadcastHashOuterJoin])
       ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
-    } finally {
-      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
     }
-
     ctx.sql("UNCACHE TABLE testData")
   }
 
@@ -180,7 +174,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
     val x = testData2.as("x")
     val y = testData2.as("y")
     val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
-    val planned = ctx.planner.HashJoin(join)
+    val planned = ctx.planner.EquiJoinSelection(join)
     assert(planned.size === 1)
   }
 
@@ -457,25 +451,24 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
   test("broadcasted left semi join operator selection") {
     ctx.cacheManager.clearCache()
     ctx.sql("CACHE TABLE testData")
-    val tmp = ctx.conf.autoBroadcastJoinThreshold
 
-    ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000")
-    Seq(
-      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
-        classOf[BroadcastLeftSemiJoinHash])
-    ).foreach {
-      case (query, joinClass) => assertJoin(query, joinClass)
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
+      Seq(
+        ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
+          classOf[BroadcastLeftSemiJoinHash])
+      ).foreach {
+        case (query, joinClass) => assertJoin(query, joinClass)
+      }
     }
 
-    ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1")
-
-    Seq(
-      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
-    ).foreach {
-      case (query, joinClass) => assertJoin(query, joinClass)
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      Seq(
+        ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
+      ).foreach {
+        case (query, joinClass) => assertJoin(query, joinClass)
+      }
     }
 
-    ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp)
     ctx.sql("UNCACHE TABLE testData")
   }
 
@@ -488,6 +481,5 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
         Row(2, 2) ::
         Row(3, 1) ::
         Row(3, 2) :: Nil)
-
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
new file mode 100644
index 0000000..ddff7ce
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.execution._
+
+class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
+
+  private def testInnerJoin(
+      testName: String,
+      leftRows: DataFrame,
+      rightRows: DataFrame,
+      condition: Expression,
+      expectedAnswer: Seq[Product]): Unit = {
+    val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+    ExtractEquiJoinKeys.unapply(join).foreach {
+      case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
+
+        def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
+          val broadcastHashJoin =
+            execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
+          boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+        }
+
+        def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
+          val shuffledHashJoin =
+            execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
+          val filteredJoin =
+            boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
+          EnsureRequirements(sqlContext).apply(filteredJoin)
+        }
+
+        def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
+          val sortMergeJoin =
+            execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
+          val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
+          EnsureRequirements(sqlContext).apply(filteredJoin)
+        }
+
+        test(s"$testName using BroadcastHashJoin (build=left)") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+              makeBroadcastHashJoin(left, right, joins.BuildLeft),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
+          }
+        }
+
+        test(s"$testName using BroadcastHashJoin (build=right)") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+              makeBroadcastHashJoin(left, right, joins.BuildRight),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
+          }
+        }
+
+        test(s"$testName using ShuffledHashJoin (build=left)") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+              makeShuffledHashJoin(left, right, joins.BuildLeft),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
+          }
+        }
+
+        test(s"$testName using ShuffledHashJoin (build=right)") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+              makeShuffledHashJoin(left, right, joins.BuildRight),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
+          }
+        }
+
+        test(s"$testName using SortMergeJoin") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+              makeSortMergeJoin(left, right),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
+          }
+        }
+    }
+  }
+
+  {
+    val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+      Row(1, "A"),
+      Row(2, "B"),
+      Row(3, "C"),
+      Row(4, "D"),
+      Row(5, "E"),
+      Row(6, "F"),
+      Row(null, "G")
+    )), new StructType().add("N", IntegerType).add("L", StringType))
+
+    val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+      Row(1, "a"),
+      Row(2, "b"),
+      Row(3, "c"),
+      Row(4, "d"),
+      Row(null, "e")
+    )), new StructType().add("n", IntegerType).add("l", StringType))
+
+    testInnerJoin(
+      "inner join, one match per row",
+      upperCaseData,
+      lowerCaseData,
+      (upperCaseData.col("N") === lowerCaseData.col("n")).expr,
+      Seq(
+        (1, "A", 1, "a"),
+        (2, "B", 2, "b"),
+        (3, "C", 3, "c"),
+        (4, "D", 4, "d")
+      )
+    )
+  }
+
+  private val testData2 = Seq(
+    (1, 1),
+    (1, 2),
+    (2, 1),
+    (2, 2),
+    (3, 1),
+    (3, 2)
+  ).toDF("a", "b")
+
+  {
+    val left = testData2.where("a = 1")
+    val right = testData2.where("a = 1")
+    testInnerJoin(
+      "inner join, multiple matches",
+      left,
+      right,
+      (left.col("a") === right.col("a")).expr,
+      Seq(
+        (1, 1, 1, 1),
+        (1, 1, 1, 2),
+        (1, 2, 1, 1),
+        (1, 2, 1, 2)
+      )
+    )
+  }
+
+  {
+    val left = testData2.where("a = 1")
+    val right = testData2.where("a = 2")
+    testInnerJoin(
+      "inner join, no matches",
+      left,
+      right,
+      (left.col("a") === right.col("a")).expr,
+      Seq.empty
+    )
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 2c27da5..e16f5e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -1,89 +1,221 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan}
-import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
-
-class OuterJoinSuite extends SparkPlanTest {
-
-  val left = Seq(
-    (1, 2.0),
-    (2, 1.0),
-    (3, 3.0)
-  ).toDF("a", "b")
-
-  val right = Seq(
-    (2, 3.0),
-    (3, 2.0),
-    (4, 1.0)
-  ).toDF("c", "d")
-
-  val leftKeys: List[Expression] = 'a :: Nil
-  val rightKeys: List[Expression] = 'c :: Nil
-  val condition = Some(LessThan('b, 'd))
-
-  test("shuffled hash outer join") {
-    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
-      ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
-      Seq(
-        (1, 2.0, null, null),
-        (2, 1.0, 2, 3.0),
-        (3, 3.0, null, null)
-      ).map(Row.fromTuple))
-
-    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
-      ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
-      Seq(
-        (2, 1.0, 2, 3.0),
-        (null, null, 3, 2.0),
-        (null, null, 4, 1.0)
-      ).map(Row.fromTuple))
-
-    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
-      ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
-      Seq(
-        (1, 2.0, null, null),
-        (2, 1.0, 2, 3.0),
-        (3, 3.0, null, null),
-        (null, null, 3, 2.0),
-        (null, null, 4, 1.0)
-      ).map(Row.fromTuple))
-  }
-
-  test("broadcast hash outer join") {
-    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
-      BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
-      Seq(
-        (1, 2.0, null, null),
-        (2, 1.0, 2, 3.0),
-        (3, 3.0, null, null)
-      ).map(Row.fromTuple))
-
-    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
-      BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
-      Seq(
-        (2, 1.0, 2, 3.0),
-        (null, null, 3, 2.0),
-        (null, null, 4, 1.0)
-      ).map(Row.fromTuple))
-  }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
+import org.apache.spark.sql.{SQLConf, DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest}
+
+class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
+
+  private def testOuterJoin(
+      testName: String,
+      leftRows: DataFrame,
+      rightRows: DataFrame,
+      joinType: JoinType,
+      condition: Expression,
+      expectedAnswer: Seq[Product]): Unit = {
+    val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+    ExtractEquiJoinKeys.unapply(join).foreach {
+      case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
+        test(s"$testName using ShuffledHashOuterJoin") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+              EnsureRequirements(sqlContext).apply(
+                ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
+          }
+        }
+
+        if (joinType != FullOuter) {
+          test(s"$testName using BroadcastHashOuterJoin") {
+            withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+              checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+                BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
+                expectedAnswer.map(Row.fromTuple),
+                sortAnswers = true)
+            }
+          }
+
+          test(s"$testName using SortMergeOuterJoin") {
+            withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+              checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+                EnsureRequirements(sqlContext).apply(
+                  SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+                expectedAnswer.map(Row.fromTuple),
+                sortAnswers = false)
+            }
+          }
+        }
+    }
+
+    test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
+
+    test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
+  }
+
+  val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+    Row(1, 2.0),
+    Row(2, 100.0),
+    Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
+    Row(2, 1.0),
+    Row(3, 3.0),
+    Row(5, 1.0),
+    Row(6, 6.0),
+    Row(null, null)
+  )), new StructType().add("a", IntegerType).add("b", DoubleType))
+
+  val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+    Row(0, 0.0),
+    Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
+    Row(2, -1.0),
+    Row(2, -1.0),
+    Row(2, 3.0),
+    Row(3, 2.0),
+    Row(4, 1.0),
+    Row(5, 3.0),
+    Row(7, 7.0),
+    Row(null, null)
+  )), new StructType().add("c", IntegerType).add("d", DoubleType))
+
+  val condition = {
+    And(
+      (left.col("a") === right.col("c")).expr,
+      LessThan(left.col("b").expr, right.col("d").expr))
+  }
+
+  // --- Basic outer joins ------------------------------------------------------------------------
+
+  testOuterJoin(
+    "basic left outer join",
+    left,
+    right,
+    LeftOuter,
+    condition,
+    Seq(
+      (null, null, null, null),
+      (1, 2.0, null, null),
+      (2, 100.0, null, null),
+      (2, 1.0, 2, 3.0),
+      (2, 1.0, 2, 3.0),
+      (2, 1.0, 2, 3.0),
+      (2, 1.0, 2, 3.0),
+      (3, 3.0, null, null),
+      (5, 1.0, 5, 3.0),
+      (6, 6.0, null, null)
+    )
+  )
+
+  testOuterJoin(
+    "basic right outer join",
+    left,
+    right,
+    RightOuter,
+    condition,
+    Seq(
+      (null, null, null, null),
+      (null, null, 0, 0.0),
+      (2, 1.0, 2, 3.0),
+      (2, 1.0, 2, 3.0),
+      (null, null, 2, -1.0),
+      (null, null, 2, -1.0),
+      (2, 1.0, 2, 3.0),
+      (2, 1.0, 2, 3.0),
+      (null, null, 3, 2.0),
+      (null, null, 4, 1.0),
+      (5, 1.0, 5, 3.0),
+      (null, null, 7, 7.0)
+    )
+  )
+
+  testOuterJoin(
+    "basic full outer join",
+    left,
+    right,
+    FullOuter,
+    condition,
+    Seq(
+      (1, 2.0, null, null),
+      (null, null, 2, -1.0),
+      (null, null, 2, -1.0),
+      (2, 100.0, null, null),
+      (2, 1.0, 2, 3.0),
+      (2, 1.0, 2, 3.0),
+      (2, 1.0, 2, 3.0),
+      (2, 1.0, 2, 3.0),
+      (3, 3.0, null, null),
+      (5, 1.0, 5, 3.0),
+      (6, 6.0, null, null),
+      (null, null, 0, 0.0),
+      (null, null, 3, 2.0),
+      (null, null, 4, 1.0),
+      (null, null, 7, 7.0),
+      (null, null, null, null),
+      (null, null, null, null)
+    )
+  )
+
+  // --- Both inputs empty ------------------------------------------------------------------------
+
+  testOuterJoin(
+    "left outer join with both inputs empty",
+    left.filter("false"),
+    right.filter("false"),
+    LeftOuter,
+    condition,
+    Seq.empty
+  )
+
+  testOuterJoin(
+    "right outer join with both inputs empty",
+    left.filter("false"),
+    right.filter("false"),
+    RightOuter,
+    condition,
+    Seq.empty
+  )
+
+  testOuterJoin(
+    "full outer join with both inputs empty",
+    left.filter("false"),
+    right.filter("false"),
+    FullOuter,
+    condition,
+    Seq.empty
+  )
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 927e85a..4503ed2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -17,58 +17,91 @@
 
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression}
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
+import org.apache.spark.sql.{SQLConf, DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
+import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
 
+class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
 
-class SemiJoinSuite extends SparkPlanTest{
-  val left = Seq(
-    (1, 2.0),
-    (1, 2.0),
-    (2, 1.0),
-    (2, 1.0),
-    (3, 3.0)
-  ).toDF("a", "b")
+  private def testLeftSemiJoin(
+      testName: String,
+      leftRows: DataFrame,
+      rightRows: DataFrame,
+      condition: Expression,
+      expectedAnswer: Seq[Product]): Unit = {
+    val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+    ExtractEquiJoinKeys.unapply(join).foreach {
+      case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
+        test(s"$testName using LeftSemiJoinHash") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+              EnsureRequirements(left.sqlContext).apply(
+                LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
+          }
+        }
 
-  val right = Seq(
-    (2, 3.0),
-    (2, 3.0),
-    (3, 2.0),
-    (4, 1.0)
-  ).toDF("c", "d")
+        test(s"$testName using BroadcastLeftSemiJoinHash") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+              BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
+          }
+        }
+    }
 
-  val leftKeys: List[Expression] = 'a :: Nil
-  val rightKeys: List[Expression] = 'c :: Nil
-  val condition = Some(LessThan('b, 'd))
-
-  test("left semi join hash") {
-    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
-      LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
-      Seq(
-        (2, 1.0),
-        (2, 1.0)
-      ).map(Row.fromTuple))
+    test(s"$testName using LeftSemiJoinBNL") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          LeftSemiJoinBNL(left, right, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
   }
 
-  test("left semi join BNL") {
-    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
-      LeftSemiJoinBNL(left, right, condition),
-      Seq(
-        (1, 2.0),
-        (1, 2.0),
-        (2, 1.0),
-        (2, 1.0)
-      ).map(Row.fromTuple))
-  }
+  val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+    Row(1, 2.0),
+    Row(1, 2.0),
+    Row(2, 1.0),
+    Row(2, 1.0),
+    Row(3, 3.0),
+    Row(null, null),
+    Row(null, 5.0),
+    Row(6, null)
+  )), new StructType().add("a", IntegerType).add("b", DoubleType))
 
-  test("broadcast left semi join hash") {
-    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
-      BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
-      Seq(
-        (2, 1.0),
-        (2, 1.0)
-      ).map(Row.fromTuple))
+  val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+    Row(2, 3.0),
+    Row(2, 3.0),
+    Row(3, 2.0),
+    Row(4, 1.0),
+    Row(null, null),
+    Row(null, 5.0),
+    Row(6, null)
+  )), new StructType().add("c", IntegerType).add("d", DoubleType))
+
+  val condition = {
+    And(
+      (left.col("a") === right.col("c")).expr,
+      LessThan(left.col("b").expr, right.col("d").expr))
   }
+
+  testLeftSemiJoin(
+    "basic test",
+    left,
+    right,
+    condition,
+    Seq(
+      (2, 1.0),
+      (2, 1.0)
+    )
+  )
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 4c11acd..1066695 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.SQLContext
 import org.apache.spark.util.Utils
 
 trait SQLTestUtils { this: SparkFunSuite =>
-  def sqlContext: SQLContext
+  protected def sqlContext: SQLContext
 
   protected def configuration = sqlContext.sparkContext.hadoopConfiguration
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f9beef99/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 567d7fa..f17177a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -531,7 +531,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
       HashAggregation,
       Aggregation,
       LeftSemiJoin,
-      HashJoin,
+      EquiJoinSelection,
       BasicOperators,
       CartesianProduct,
       BroadcastNestedLoopJoin


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to