Repository: spark
Updated Branches:
  refs/heads/master 4754e16f4 -> 585638e81


[SPARK-2213] [SQL] sort merge join for spark sql

Thanks for the initial work from Ishiihara in #3173

This PR introduces a new join method, sort merge join, which first ensures 
that keys with the same value are in the same partition and that, inside each 
partition, the rows are sorted by key. We can then walk down both sides 
together and find matching rows using [sort merge 
join](http://en.wikipedia.org/wiki/Sort-merge_join). This way, we don't have 
to store the whole hash table of one side as hash join does, so we use less 
memory. Also, this PR would benefit from #3438, making the sorting 
phase much more efficient.

We introduced a new configuration, "spark.sql.planner.sortMergeJoin", to 
switch between this (`true`) and ShuffledHashJoin (`false`); we probably want 
its default value to be `false` at first.

Author: Daoyuan Wang <daoyuan.w...@intel.com>
Author: Michael Armbrust <mich...@databricks.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <mich...@databricks.com>

Closes #5208 from adrian-wang/smj and squashes the following commits:

2493b9f [Daoyuan Wang] fix style
5049d88 [Daoyuan Wang] propagate rowOrdering for RangePartitioning
f91a2ae [Daoyuan Wang] yin's comment: use external sort if option is enabled, 
add comments
f515cd2 [Daoyuan Wang] yin's comment: outputOrdering, join suite refine
ec8061b [Daoyuan Wang] minor change
413fd24 [Daoyuan Wang] Merge pull request #3 from marmbrus/pr/5208
952168a [Michael Armbrust] add type
5492884 [Michael Armbrust] copy when ordering
7ddd656 [Michael Armbrust] Cleanup addition of ordering requirements
b198278 [Daoyuan Wang] inherit ordering in project
c8e82a3 [Daoyuan Wang] fix style
6e897dd [Daoyuan Wang] hide boundReference from manually construct RowOrdering 
for key compare in smj
8681d73 [Daoyuan Wang] refactor Exchange and fix copy for sorting
2875ef2 [Daoyuan Wang] fix changed configuration
61d7f49 [Daoyuan Wang] add omitted comment
00a4430 [Daoyuan Wang] fix bug
078d69b [Daoyuan Wang] address comments: add comments, do sort in shuffle, and 
others
3af6ba5 [Daoyuan Wang] use buffer for only one side
171001f [Daoyuan Wang] change default outputordering
47455c9 [Daoyuan Wang] add apache license ...
a28277f [Daoyuan Wang] fix style
645c70b [Daoyuan Wang] address comments using sort
068c35d [Daoyuan Wang] fix new style and add some tests
925203b [Daoyuan Wang] address comments
07ce92f [Daoyuan Wang] fix ArrayIndexOutOfBound
42fca0e [Daoyuan Wang] code clean
e3ec096 [Daoyuan Wang] fix comment style..
2edd235 [Daoyuan Wang] fix outputpartitioning
57baa40 [Daoyuan Wang] fix sort eval bug
303b6da [Daoyuan Wang] fix several errors
95db7ad [Daoyuan Wang] fix brackets for if-statement
4464f16 [Daoyuan Wang] fix error
880d8e9 [Daoyuan Wang] sort merge join for spark sql


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/585638e8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/585638e8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/585638e8

Branch: refs/heads/master
Commit: 585638e81ce09a72b9e7f95d38e0d432cfa02456
Parents: 4754e16
Author: Daoyuan Wang <daoyuan.w...@intel.com>
Authored: Wed Apr 15 14:06:10 2015 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Wed Apr 15 14:06:10 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/rows.scala   |  10 +-
 .../catalyst/plans/physical/partitioning.scala  |  13 ++
 .../scala/org/apache/spark/sql/SQLConf.scala    |   8 +
 .../scala/org/apache/spark/sql/SQLContext.scala |   2 +-
 .../apache/spark/sql/execution/Exchange.scala   | 148 +++++++++++++---
 .../apache/spark/sql/execution/SparkPlan.scala  |   6 +
 .../spark/sql/execution/SparkStrategies.scala   |  11 +-
 .../spark/sql/execution/basicOperators.scala    |  10 ++
 .../sql/execution/joins/SortMergeJoin.scala     | 169 +++++++++++++++++++
 .../scala/org/apache/spark/sql/JoinSuite.scala  |  28 ++-
 .../execution/SortMergeCompatibilitySuite.scala | 162 ++++++++++++++++++
 11 files changed, 534 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 1b62e17..b6ec7d3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.types.{UTF8String, StructType, NativeType}
-
+import org.apache.spark.sql.types.{UTF8String, DataType, StructType, 
NativeType}
 
 /**
  * An extended interface to [[Row]] that allows the values for each column to 
be updated.  Setting
@@ -239,3 +238,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends 
Ordering[Row] {
     return 0
   }
 }
+
+object RowOrdering {
+  def forSchema(dataTypes: Seq[DataType]): RowOrdering =
+    new RowOrdering(dataTypes.zipWithIndex.map {
+      case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = 
true), Ascending)
+    })
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 288c11f..fb4217a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -94,6 +94,9 @@ sealed trait Partitioning {
    * only compatible if the `numPartitions` of them is the same.
    */
   def compatibleWith(other: Partitioning): Boolean
+
+  /** Returns the expressions that are used to key the partitioning. */
+  def keyExpressions: Seq[Expression]
 }
 
 case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends 
Partitioning {
     case UnknownPartitioning(_) => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 case object SinglePartition extends Partitioning {
@@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning {
     case SinglePartition => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 case object BroadcastPartitioning extends Partitioning {
@@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning {
     case SinglePartition => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 /**
@@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], 
numPartitions: Int)
     case _ => false
   }
 
+  override def keyExpressions: Seq[Expression] = expressions
+
   override def eval(input: Row = null): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}")
 }
@@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], 
numPartitions: Int)
     case _ => false
   }
 
+  override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+
   override def eval(input: Row): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}")
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index ee641bd..5c65f04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -47,6 +47,7 @@ private[spark] object SQLConf {
   // Options that control which operators can be chosen by the query planner.  
These should be
   // considered hints and may be ignored by future versions of Spark SQL.
   val EXTERNAL_SORT = "spark.sql.planner.externalSort"
+  val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin"
 
   // This is only used for the thriftserver
   val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -129,6 +130,13 @@ private[sql] class SQLConf extends Serializable {
   private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, 
"false").toBoolean
 
   /**
+   * Sort merge join would sort the two side of join first, and then iterate 
both sides together
+   * only once to get all matches. Using sort merge join can save a lot of 
memory usage compared
+   * to HashJoin.
+   */
+  private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, 
"false").toBoolean
+
+  /**
    * When set to true, Spark SQL will use the Scala compiler at runtime to 
generate custom bytecode
    * that evaluates expressions found in queries.  In general this custom code 
runs much faster
    * than interpreted evaluation, but there are significant start-up costs due 
to compilation.

http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 89a4faf..f9f3eb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1081,7 +1081,7 @@ class SQLContext(@transient val sparkContext: 
SparkContext)
   @transient
   protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
     val batches =
-      Batch("Add exchange", Once, AddExchange(self)) :: Nil
+      Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
   }
 
   protected[sql] def openSession(): SQLSession = {

http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 437408d..518fc9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -19,24 +19,42 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, 
SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.util.MutablePair
 
+object Exchange {
+  /**
+   * Returns true when the ordering expressions are a subset of the key.
+   * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within 
[[Exchange]].
+   */
+  def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: 
Seq[SortOrder]): Boolean = {
+    
desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
+  }
+}
+
 /**
  * :: DeveloperApi ::
+ * Performs a shuffle that will result in the desired `newPartitioning`.  
Optionally sorts each
+ * resulting partition based on expressions from the partition key.  It is 
invalid to construct an
+ * exchange operator with a `newOrdering` that cannot be calculated using the 
partitioning key.
  */
 @DeveloperApi
-case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends 
UnaryNode {
+case class Exchange(
+    newPartitioning: Partitioning,
+    newOrdering: Seq[SortOrder],
+    child: SparkPlan)
+  extends UnaryNode {
 
   override def outputPartitioning: Partitioning = newPartitioning
 
+  override def outputOrdering: Seq[SortOrder] = newOrdering
+
   override def output: Seq[Attribute] = child.output
 
   /** We must copy rows when sort based shuffle is on */
@@ -45,6 +63,20 @@ case class Exchange(newPartitioning: Partitioning, child: 
SparkPlan) extends Una
   private val bypassMergeThreshold =
     
child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold",
 200)
 
+  private val keyOrdering = {
+    if (newOrdering.nonEmpty) {
+      val key = newPartitioning.keyExpressions
+      val boundOrdering = newOrdering.map { o =>
+        val ordinal = key.indexOf(o.child)
+        if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for 
$newPartitioning")
+        o.copy(child = BoundReference(ordinal, o.child.dataType, 
o.child.nullable))
+      }
+      new RowOrdering(boundOrdering)
+    } else {
+      null // Ordering will not be used
+    }
+  }
+
   override def execute(): RDD[Row] = attachTree(this , "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
@@ -56,7 +88,9 @@ case class Exchange(newPartitioning: Partitioning, child: 
SparkPlan) extends Una
         // we can avoid the defensive copies to improve performance. In the 
long run, we probably
         // want to include information in shuffle dependencies to indicate 
whether elements in the
         // source RDD should be copied.
-        val rdd = if (sortBasedShuffleOn && numPartitions > 
bypassMergeThreshold) {
+        val willMergeSort = sortBasedShuffleOn && numPartitions > 
bypassMergeThreshold
+
+        val rdd = if (willMergeSort || newOrdering.nonEmpty) {
           child.execute().mapPartitions { iter =>
             val hashExpressions = newMutableProjection(expressions, 
child.output)()
             iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -69,12 +103,17 @@ case class Exchange(newPartitioning: Partitioning, child: 
SparkPlan) extends Una
           }
         }
         val part = new HashPartitioner(numPartitions)
-        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+        val shuffled =
+          if (newOrdering.nonEmpty) {
+            new ShuffledRDD[Row, Row, Row](rdd, 
part).setKeyOrdering(keyOrdering)
+          } else {
+            new ShuffledRDD[Row, Row, Row](rdd, part)
+          }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
-        val rdd = if (sortBasedShuffleOn) {
+        val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) {
           child.execute().mapPartitions { iter => iter.map(row => (row.copy(), 
null))}
         } else {
           child.execute().mapPartitions { iter =>
@@ -87,7 +126,12 @@ case class Exchange(newPartitioning: Partitioning, child: 
SparkPlan) extends Una
         implicit val ordering = new RowOrdering(sortingExpressions, 
child.output)
 
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+        val shuffled =
+          if (newOrdering.nonEmpty) {
+            new ShuffledRDD[Row, Null, Null](rdd, 
part).setKeyOrdering(keyOrdering)
+          } else {
+            new ShuffledRDD[Row, Null, Null](rdd, part)
+          }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
 
         shuffled.map(_._1)
@@ -120,27 +164,34 @@ case class Exchange(newPartitioning: Partitioning, child: 
SparkPlan) extends Una
  * Ensures that the 
[[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
  * of input data meets the
  * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] 
requirements for
- * each operator by inserting [[Exchange]] Operators where required.
+ * each operator by inserting [[Exchange]] Operators where required.  Also 
ensure that the
+ * required input partition ordering requirements are met.
  */
-private[sql] case class AddExchange(sqlContext: SQLContext) extends 
Rule[SparkPlan] {
+private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends 
Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
   def numPartitions: Int = sqlContext.conf.numShufflePartitions
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     case operator: SparkPlan =>
-      // Check if every child's outputPartitioning satisfies the corresponding
+      // True iff every child's outputPartitioning satisfies the corresponding
       // required data distribution.
       def meetsRequirements: Boolean =
-        !operator.requiredChildDistribution.zip(operator.children).map {
+        operator.requiredChildDistribution.zip(operator.children).forall {
           case (required, child) =>
             val valid = child.outputPartitioning.satisfies(required)
             logDebug(
               s"${if (valid) "Valid" else "Invalid"} distribution," +
                 s"required: $required current: ${child.outputPartitioning}")
             valid
-        }.exists(!_)
+        }
 
-      // Check if outputPartitionings of children are compatible with each 
other.
+      // True iff any of the children are incorrectly sorted.
+      def needsAnySort: Boolean =
+        operator.requiredChildOrdering.zip(operator.children).exists {
+          case (required, child) => required.nonEmpty && required != 
child.outputOrdering
+        }
+
+      // True iff outputPartitionings of children are compatible with each 
other.
       // It is possible that every child satisfies its required data 
distribution
       // but two children have incompatible outputPartitionings. For example,
       // A dataset is range partitioned by "a.asc" (RangePartitioning) and 
another
@@ -157,28 +208,69 @@ private[sql] case class AddExchange(sqlContext: 
SQLContext) extends Rule[SparkPl
             case Seq(a,b) => a compatibleWith b
           }.exists(!_)
 
-      // Check if the partitioning we want to ensure is the same as the 
child's output
-      // partitioning. If so, we do not need to add the Exchange operator.
-      def addExchangeIfNecessary(partitioning: Partitioning, child: 
SparkPlan): SparkPlan =
-        if (child.outputPartitioning != partitioning) Exchange(partitioning, 
child) else child
+      // Adds Exchange or Sort operators as required
+      def addOperatorsIfNecessary(
+          partitioning: Partitioning,
+          rowOrdering: Seq[SortOrder],
+          child: SparkPlan): SparkPlan = {
+        val needSort = rowOrdering.nonEmpty && child.outputOrdering != 
rowOrdering
+        val needsShuffle = child.outputPartitioning != partitioning
+        val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, 
rowOrdering)
+
+        if (needSort && needsShuffle && canSortWithShuffle) {
+          Exchange(partitioning, rowOrdering, child)
+        } else {
+          val withShuffle = if (needsShuffle) {
+            Exchange(partitioning, Nil, child)
+          } else {
+            child
+          }
 
-      if (meetsRequirements && compatible) {
+          val withSort = if (needSort) {
+            if (sqlContext.conf.externalSortEnabled) {
+              ExternalSort(rowOrdering, global = false, withShuffle)
+            } else {
+              Sort(rowOrdering, global = false, withShuffle)
+            }
+          } else {
+            withShuffle
+          }
+
+          withSort
+        }
+      }
+
+      if (meetsRequirements && compatible && !needsAnySort) {
         operator
       } else {
         // At least one child does not satisfies its required data 
distribution or
         // at least one child's outputPartitioning is not compatible with 
another child's
         // outputPartitioning. In this case, we need to add Exchange operators.
-        val repartitionedChildren = 
operator.requiredChildDistribution.zip(operator.children).map {
-          case (AllTuples, child) =>
-            addExchangeIfNecessary(SinglePartition, child)
-          case (ClusteredDistribution(clustering), child) =>
-            addExchangeIfNecessary(HashPartitioning(clustering, 
numPartitions), child)
-          case (OrderedDistribution(ordering), child) =>
-            addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), 
child)
-          case (UnspecifiedDistribution, child) => child
-          case (dist, _) => sys.error(s"Don't know how to ensure $dist")
+        val requirements =
+          (operator.requiredChildDistribution, operator.requiredChildOrdering, 
operator.children)
+
+        val fixedChildren = requirements.zipped.map {
+          case (AllTuples, rowOrdering, child) =>
+            addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+          case (ClusteredDistribution(clustering), rowOrdering, child) =>
+            addOperatorsIfNecessary(HashPartitioning(clustering, 
numPartitions), rowOrdering, child)
+          case (OrderedDistribution(ordering), rowOrdering, child) =>
+            addOperatorsIfNecessary(RangePartitioning(ordering, 
numPartitions), rowOrdering, child)
+
+          case (UnspecifiedDistribution, Seq(), child) =>
+            child
+          case (UnspecifiedDistribution, rowOrdering, child) =>
+            if (sqlContext.conf.externalSortEnabled) {
+              ExternalSort(rowOrdering, global = false, child)
+            } else {
+              Sort(rowOrdering, global = false, child)
+            }
+
+          case (dist, ordering, _) =>
+            sys.error(s"Don't know how to ensure $dist with ordering 
$ordering")
         }
-        operator.withNewChildren(repartitionedChildren)
+
+        operator.withNewChildren(fixedChildren)
       }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index fabcf6b..e159ffe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with 
Logging with Serializ
   def requiredChildDistribution: Seq[Distribution] =
     Seq.fill(children.size)(UnspecifiedDistribution)
 
+  /** Specifies how data is ordered in each partition. */
+  def outputOrdering: Seq[SortOrder] = Nil
+
+  /** Specifies sort order for each partition requirements on the input data 
for this operator. */
+  def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
+
   /**
    * Runs this query returning the result as an RDD.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5b99e40..e687d01 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -90,6 +90,14 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
            left.statistics.sizeInBytes <= 
sqlContext.conf.autoBroadcastJoinThreshold =>
           makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, 
joins.BuildLeft)
 
+      // If the sort merge join option is set, we want to use sort merge join 
prior to hashjoin
+      // for now let's support inner join first, then add outer join
+      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, 
right)
+        if sqlContext.conf.sortMergeJoinEnabled =>
+        val mergeJoin =
+          joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), 
planLater(right))
+        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, 
right) =>
         val buildSide =
           if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
@@ -309,7 +317,8 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       case logical.OneRowRelation =>
         execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
-        execution.Exchange(HashPartitioning(expressions, numPartitions), 
planLater(child)) :: Nil
+        execution.Exchange(
+          HashPartitioning(expressions, numPartitions), Nil, planLater(child)) 
:: Nil
       case e @ EvaluatePython(udf, child, _) =>
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index f8221f4..308dae2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -41,6 +41,8 @@ case class Project(projectList: Seq[NamedExpression], child: 
SparkPlan) extends
     val resuableProjection = buildProjection()
     iter.map(resuableProjection)
   }
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
 /**
@@ -55,6 +57,8 @@ case class Filter(condition: Expression, child: SparkPlan) 
extends UnaryNode {
   override def execute(): RDD[Row] = child.execute().mapPartitions { iter =>
     iter.filter(conditionEvaluator)
   }
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
 /**
@@ -147,6 +151,8 @@ case class TakeOrdered(limit: Int, sortOrder: 
Seq[SortOrder], child: SparkPlan)
   // TODO: Terminal split should be implemented differently from non-terminal 
split.
   // TODO: Pick num splits based on |limit|.
   override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
@@ -172,6 +178,8 @@ case class Sort(
   }
 
   override def output: Seq[Attribute] = child.output
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
@@ -202,6 +210,8 @@ case class ExternalSort(
   }
 
   override def output: Seq[Attribute] = child.output
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
new file mode 100644
index 0000000..b512366
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import java.util.NoSuchElementException
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+
+/**
+ * :: DeveloperApi ::
+ * Performs a sort merge join of two child relations.
+ */
+@DeveloperApi
+case class SortMergeJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  // this is to manually construct an ordering that can be used to compare 
keys from both sides
+  private val keyOrdering: RowOrdering = 
RowOrdering.forSchema(leftKeys.map(_.dataType))
+
+  override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
+
+  @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, 
left.output)
+  @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, 
right.output)
+
+  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
+    keys.map(SortOrder(_, Ascending))
+
+  override def execute(): RDD[Row] = {
+    val leftResults = left.execute().map(_.copy())
+    val rightResults = right.execute().map(_.copy())
+
+    leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
+      new Iterator[Row] {
+        // Mutable per row objects.
+        private[this] val joinRow = new JoinedRow5
+        private[this] var leftElement: Row = _
+        private[this] var rightElement: Row = _
+        private[this] var leftKey: Row = _
+        private[this] var rightKey: Row = _
+        private[this] var rightMatches: CompactBuffer[Row] = _
+        private[this] var rightPosition: Int = -1
+        private[this] var stop: Boolean = false
+        private[this] var matchKey: Row = _
+
+        // initialize iterator
+        initialize()
+
+        override final def hasNext: Boolean = nextMatchingPair()
+
+        override final def next(): Row = {
+          if (hasNext) {
+            // we use the buffered right rows while running down the left iterator
+            val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
+            rightPosition += 1
+            if (rightPosition >= rightMatches.size) {
+              rightPosition = 0
+              fetchLeft()
+              if (leftElement == null || keyOrdering.compare(leftKey, 
matchKey) != 0) {
+                stop = false
+                rightMatches = null
+              }
+            }
+            joinedRow
+          } else {
+            // no more result
+            throw new NoSuchElementException
+          }
+        }
+
+        private def fetchLeft() = {
+          if (leftIter.hasNext) {
+            leftElement = leftIter.next()
+            leftKey = leftKeyGenerator(leftElement)
+          } else {
+            leftElement = null
+          }
+        }
+
+        private def fetchRight() = {
+          if (rightIter.hasNext) {
+            rightElement = rightIter.next()
+            rightKey = rightKeyGenerator(rightElement)
+          } else {
+            rightElement = null
+          }
+        }
+
+        private def initialize() = {
+          fetchLeft()
+          fetchRight()
+        }
+
+        /**
+         * Searches the right iterator for the next rows that have matches on the 
left side, and stores
+         * them in a buffer.
+         *
+         * @return true if the search is successful, and false if the right 
iterator runs out of
+         *         tuples.
+         */
+        private def nextMatchingPair(): Boolean = {
+          if (!stop && rightElement != null) {
+            // run both sides to get the first matching pair
+            while (!stop && leftElement != null && rightElement != null) {
+              val comparing = keyOrdering.compare(leftKey, rightKey)
+              // for inner join, we need to filter those null keys
+              stop = comparing == 0 && !leftKey.anyNull
+              if (comparing > 0 || rightKey.anyNull) {
+                fetchRight()
+              } else if (comparing < 0 || leftKey.anyNull) {
+                fetchLeft()
+              }
+            }
+            rightMatches = new CompactBuffer[Row]()
+            if (stop) {
+              stop = false
+              // iterate the right side to buffer all rows that match
+              // as the records should be ordered, exit when we meet the first 
that does not match
+              while (!stop && rightElement != null) {
+                rightMatches += rightElement
+                fetchRight()
+                stop = keyOrdering.compare(leftKey, rightKey) != 0
+              }
+              if (rightMatches.size > 0) {
+                rightPosition = 0
+                matchKey = leftKey
+              }
+            }
+          }
+          rightMatches != null && rightMatches.size > 0
+        }
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index e4dee87..037d392 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
       case j: BroadcastLeftSemiJoinHash => j
+      case j: SortMergeJoin => j
     }
 
     assert(operators.size === 1)
@@ -62,6 +63,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
   test("join operator selection") {
     cacheManager.clearCache()
 
+    val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled
     Seq(
       ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", 
classOf[LeftSemiJoinHash]),
       ("SELECT * FROM testData LEFT SEMI JOIN testData2", 
classOf[LeftSemiJoinBNL]),
@@ -91,17 +93,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
         classOf[BroadcastNestedLoopJoin])
     ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    try {
+      conf.setConf("spark.sql.planner.sortMergeJoin", "true")
+      Seq(
+        ("SELECT * FROM testData JOIN testData2 ON key = a", 
classOf[SortMergeJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", 
classOf[SortMergeJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", 
classOf[SortMergeJoin])
+      ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    } finally {
+      conf.setConf("spark.sql.planner.sortMergeJoin", 
SORTMERGEJOIN_ENABLED.toString)
+    }
   }
 
   test("broadcasted hash join operator selection") {
     cacheManager.clearCache()
     sql("CACHE TABLE testData")
 
+    val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled
     Seq(
       ("SELECT * FROM testData join testData2 ON key = a", 
classOf[BroadcastHashJoin]),
       ("SELECT * FROM testData join testData2 ON key = a and key = 2", 
classOf[BroadcastHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a where key = 2", 
classOf[BroadcastHashJoin])
+      ("SELECT * FROM testData join testData2 ON key = a where key = 2",
+        classOf[BroadcastHashJoin])
     ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    try {
+      conf.setConf("spark.sql.planner.sortMergeJoin", "true")
+      Seq(
+        ("SELECT * FROM testData join testData2 ON key = a", 
classOf[BroadcastHashJoin]),
+        ("SELECT * FROM testData join testData2 ON key = a and key = 2",
+          classOf[BroadcastHashJoin]),
+        ("SELECT * FROM testData join testData2 ON key = a where key = 2",
+          classOf[BroadcastHashJoin])
+      ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    } finally {
+      conf.setConf("spark.sql.planner.sortMergeJoin", 
SORTMERGEJOIN_ENABLED.toString)
+    }
 
     sql("UNCACHE TABLE testData")
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/585638e8/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
new file mode 100644
index 0000000..65d070b
--- /dev/null
+++ 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.hive.test.TestHive
+
+/**
+ * Runs the test cases that are included in the hive distribution with sort 
merge join enabled.
+ */
+class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
+  override def beforeAll() {
+    super.beforeAll()
+    TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true")
+  }
+
+  override def afterAll() {
+    TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false")
+    super.afterAll()
+  }
+
+  override def whiteList = Seq(
+    "auto_join0",
+    "auto_join1",
+    "auto_join10",
+    "auto_join11",
+    "auto_join12",
+    "auto_join13",
+    "auto_join14",
+    "auto_join14_hadoop20",
+    "auto_join15",
+    "auto_join17",
+    "auto_join18",
+    "auto_join19",
+    "auto_join2",
+    "auto_join20",
+    "auto_join21",
+    "auto_join22",
+    "auto_join23",
+    "auto_join24",
+    "auto_join25",
+    "auto_join26",
+    "auto_join27",
+    "auto_join28",
+    "auto_join3",
+    "auto_join30",
+    "auto_join31",
+    "auto_join32",
+    "auto_join4",
+    "auto_join5",
+    "auto_join6",
+    "auto_join7",
+    "auto_join8",
+    "auto_join9",
+    "auto_join_filters",
+    "auto_join_nulls",
+    "auto_join_reordering_values",
+    "auto_smb_mapjoin_14",
+    "auto_sortmerge_join_1",
+    "auto_sortmerge_join_10",
+    "auto_sortmerge_join_11",
+    "auto_sortmerge_join_12",
+    "auto_sortmerge_join_13",
+    "auto_sortmerge_join_14",
+    "auto_sortmerge_join_15",
+    "auto_sortmerge_join_16",
+    "auto_sortmerge_join_2",
+    "auto_sortmerge_join_3",
+    "auto_sortmerge_join_4",
+    "auto_sortmerge_join_5",
+    "auto_sortmerge_join_6",
+    "auto_sortmerge_join_7",
+    "auto_sortmerge_join_8",
+    "auto_sortmerge_join_9",
+    "correlationoptimizer1",
+    "correlationoptimizer10",
+    "correlationoptimizer11",
+    "correlationoptimizer13",
+    "correlationoptimizer14",
+    "correlationoptimizer15",
+    "correlationoptimizer2",
+    "correlationoptimizer3",
+    "correlationoptimizer4",
+    "correlationoptimizer6",
+    "correlationoptimizer7",
+    "correlationoptimizer8",
+    "correlationoptimizer9",
+    "join0",
+    "join1",
+    "join10",
+    "join11",
+    "join12",
+    "join13",
+    "join14",
+    "join14_hadoop20",
+    "join15",
+    "join16",
+    "join17",
+    "join18",
+    "join19",
+    "join2",
+    "join20",
+    "join21",
+    "join22",
+    "join23",
+    "join24",
+    "join25",
+    "join26",
+    "join27",
+    "join28",
+    "join29",
+    "join3",
+    "join30",
+    "join31",
+    "join32",
+    "join32_lessSize",
+    "join33",
+    "join34",
+    "join35",
+    "join36",
+    "join37",
+    "join38",
+    "join39",
+    "join4",
+    "join40",
+    "join41",
+    "join5",
+    "join6",
+    "join7",
+    "join8",
+    "join9",
+    "join_1to1",
+    "join_array",
+    "join_casesensitive",
+    "join_empty",
+    "join_filters",
+    "join_hive_626",
+    "join_map_ppr",
+    "join_nulls",
+    "join_nullsafe",
+    "join_rc",
+    "join_reorder2",
+    "join_reorder3",
+    "join_reorder4",
+    "join_star"
+  )
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to