Repository: spark
Updated Branches:
  refs/heads/master 30e89111d -> 114ff926f


[SPARK-2205] [SQL] Avoid unnecessary exchange operators in multi-way joins

This PR adds `PartitioningCollection`, which is used to represent the 
`outputPartitioning` of SparkPlans with multiple children (e.g. 
`ShuffledHashJoin`). With it, a `SparkPlan` can carry multiple descriptions of 
its partitioning scheme. Taking `ShuffledHashJoin` as an example, it has two 
such descriptions, namely `left.outputPartitioning` and 
`right.outputPartitioning`. As a result, a query like `select * from t1 join 
t2 on (t1.x = t2.x) join t3 on (t2.x = t3.x)` needs only three Exchange 
operators (when shuffled joins are required) instead of four.
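
To make the effect concrete, here is a minimal standalone sketch of the
`guarantees` relationship this PR introduces (simplified stand-in types, not
the actual Catalyst classes; the key names and partition counts are
illustrative):

    // Simplified model of the new API (not the real Catalyst classes).
    sealed trait Partitioning {
      def guarantees(other: Partitioning): Boolean
    }

    case class HashPartitioning(expressions: Set[String], numPartitions: Int)
        extends Partitioning {
      // Same clustering expressions and partition count => no reshuffle needed.
      def guarantees(other: Partitioning): Boolean = other match {
        case HashPartitioning(exprs, n) => expressions == exprs && numPartitions == n
        case _ => false
      }
    }

    case class PartitioningCollection(partitionings: Seq[Partitioning])
        extends Partitioning {
      // The collection guarantees `other` if any member does.
      def guarantees(other: Partitioning): Boolean =
        partitionings.exists(_.guarantees(other))
    }

    // The output of the first shuffled join (t1.x = t2.x) is described by both
    // children's partitionings:
    val firstJoin = PartitioningCollection(Seq(
      HashPartitioning(Set("t1.x"), 200),
      HashPartitioning(Set("t2.x"), 200)))

    // The second join needs its input hashed on t2.x. That guarantee already
    // holds, so the planner can skip the fourth Exchange.
    assert(firstJoin.guarantees(HashPartitioning(Set("t2.x"), 200)))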

The code in this PR was authored by yhuai; I'm opening this PR to factor out 
this change from #7685, a larger pull request which contains two other 
optimizations.


Author: Yin Huai <yh...@databricks.com>
Author: Josh Rosen <joshro...@databricks.com>

Closes #7773 from JoshRosen/multi-way-join-planning-improvements and squashes the following commits:

5c45924 [Josh Rosen] Merge remote-tracking branch 'origin/master' into multi-way-join-planning-improvements
cd8269b [Josh Rosen] Refactor test to use SQLTestUtils
2963857 [Yin Huai] Revert unnecessary SqlConf change.
73913f7 [Yin Huai] Add comments and test. Also, revert the change in ShuffledHashOuterJoin for now.
4a99204 [Josh Rosen] Delete unrelated expression change
884ab95 [Josh Rosen] Carve out only SPARK-2205 changes.
247e5fa [Josh Rosen] Merge remote-tracking branch 'origin/master' into multi-way-join-planning-improvements
c57a954 [Yin Huai] Bug fix.
d3d2e64 [Yin Huai] First round of cleanup.
f9516b0 [Yin Huai] Style
c6667e7 [Yin Huai] Add PartitioningCollection.
e616d3b [Yin Huai] wip
7c2d2d8 [Yin Huai] Bug fix and refactoring.
69bb072 [Yin Huai] Introduce NullSafeHashPartitioning and NullUnsafePartitioning.
d5b84c3 [Yin Huai] Do not add unnessary filters.
2201129 [Yin Huai] Filter out rows that will not be joined in equal joins early.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/114ff926
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/114ff926
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/114ff926

Branch: refs/heads/master
Commit: 114ff926fcd078697c1111279b5cf6173b515865
Parents: 30e8911
Author: Yin Huai <yh...@databricks.com>
Authored: Sun Aug 2 20:44:23 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Sun Aug 2 20:44:23 2015 -0700

----------------------------------------------------------------------
 .../catalyst/plans/physical/partitioning.scala  | 87 +++++++++++++++++---
 .../spark/sql/catalyst/DistributionSuite.scala  |  2 +-
 .../apache/spark/sql/execution/Exchange.scala   |  2 +-
 .../joins/BroadcastHashOuterJoin.scala          |  4 +-
 .../sql/execution/joins/HashOuterJoin.scala     |  9 --
 .../sql/execution/joins/LeftSemiJoinHash.scala  |  6 +-
 .../sql/execution/joins/ShuffledHashJoin.scala  |  7 +-
 .../execution/joins/ShuffledHashOuterJoin.scala | 10 ++-
 .../sql/execution/joins/SortMergeJoin.scala     |  3 +-
 .../spark/sql/execution/PlannerSuite.scala      | 49 ++++++++++-
 10 files changed, 148 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/114ff926/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index f4d1dba..ec659ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -60,8 +60,9 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi
 /**
  * Represents data where tuples have been ordered according to the `ordering`
  * [[Expression Expressions]].  This is a strictly stronger guarantee than
- * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for
- * the ordering expressions are contiguous and will never be split across partitions.
+ * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the
+ * same value for the ordering expressions are contiguous and will never be split across
+ * partitions.
  */
 case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
   require(
@@ -86,8 +87,12 @@ sealed trait Partitioning {
    */
   def satisfies(required: Distribution): Boolean
 
-  /** Returns the expressions that are used to key the partitioning. */
-  def keyExpressions: Seq[Expression]
+  /**
+   * Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
+   * guarantees the same partitioning scheme described by `other`.
+   */
+  // TODO: Add an example once we have the `nullSafe` concept.
+  def guarantees(other: Partitioning): Boolean
 }
 
 case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -96,7 +101,7 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
     case _ => false
   }
 
-  override def keyExpressions: Seq[Expression] = Nil
+  override def guarantees(other: Partitioning): Boolean = false
 }
 
 case object SinglePartition extends Partitioning {
@@ -104,7 +109,10 @@ case object SinglePartition extends Partitioning {
 
   override def satisfies(required: Distribution): Boolean = true
 
-  override def keyExpressions: Seq[Expression] = Nil
+  override def guarantees(other: Partitioning): Boolean = other match {
+    case SinglePartition => true
+    case _ => false
+  }
 }
 
 case object BroadcastPartitioning extends Partitioning {
@@ -112,7 +120,10 @@ case object BroadcastPartitioning extends Partitioning {
 
   override def satisfies(required: Distribution): Boolean = true
 
-  override def keyExpressions: Seq[Expression] = Nil
+  override def guarantees(other: Partitioning): Boolean = other match {
+    case BroadcastPartitioning => true
+    case _ => false
+  }
 }
 
 /**
@@ -127,7 +138,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   override def nullable: Boolean = false
   override def dataType: DataType = IntegerType
 
-  private[this] lazy val clusteringSet = expressions.toSet
+  lazy val clusteringSet = expressions.toSet
 
   override def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
@@ -136,7 +147,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     case _ => false
   }
 
-  override def keyExpressions: Seq[Expression] = expressions
+  override def guarantees(other: Partitioning): Boolean = other match {
+    case o: HashPartitioning =>
+      this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
+    case _ => false
+  }
 }
 
 /**
@@ -170,5 +185,57 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case _ => false
   }
 
-  override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+  override def guarantees(other: Partitioning): Boolean = other match {
+    case o: RangePartitioning => this == o
+    case _ => false
+  }
+}
+
+/**
+ * A collection of [[Partitioning]]s that can be used to describe the partitioning
+ * scheme of the output of a physical operator. It is usually used for an operator
+ * that has multiple children. In this case, a [[Partitioning]] in this collection
+ * describes how this operator's output is partitioned based on expressions from
+ * a child. For example, for a Join operator on two tables `A` and `B`
+ * with a join condition `A.key1 = B.key2`, assuming we use a HashPartitioning scheme,
+ * there are two [[Partitioning]]s that can be used to describe how the output of
+ * this Join operator is partitioned, which are `HashPartitioning(A.key1)` and
+ * `HashPartitioning(B.key2)`. It is also worth noting that the `partitionings`
+ * in this collection do not need to be equivalent, which is useful for
+ * Outer Join operators.
+ */
+case class PartitioningCollection(partitionings: Seq[Partitioning])
+  extends Expression with Partitioning with Unevaluable {
+
+  require(
+    partitionings.map(_.numPartitions).distinct.length == 1,
+    s"PartitioningCollection requires all of its partitionings have the same numPartitions.")
+
+  override def children: Seq[Expression] = partitionings.collect {
+    case expr: Expression => expr
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = IntegerType
+
+  override val numPartitions = partitionings.map(_.numPartitions).distinct.head
+
+  /**
+   * Returns true if any `partitioning` of this collection satisfies the given
+   * [[Distribution]].
+   */
+  override def satisfies(required: Distribution): Boolean =
+    partitionings.exists(_.satisfies(required))
+
+  /**
+   * Returns true if any `partitioning` of this collection guarantees
+   * the given [[Partitioning]].
+   */
+  override def guarantees(other: Partitioning): Boolean =
+    partitionings.exists(_.guarantees(other))
+
+  override def toString: String = {
+    partitionings.map(_.toString).mkString("(", " or ", ")")
+  }
 }
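
As a quick usage sketch of the API above (the Symbol-to-attribute conversion
comes from the Catalyst expression DSL, as in DistributionSuite below; the
keys and partition counts are made up for illustration):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.physical._

    // A ShuffledHashJoin on aKey = bKey now reports both descriptions:
    val joinOutput = PartitioningCollection(Seq(
      HashPartitioning(Seq('aKey), 10),
      HashPartitioning(Seq('bKey), 10)))

    // Either description can satisfy a downstream requirement...
    assert(joinOutput.satisfies(ClusteredDistribution(Seq('bKey))))
    // ...and guarantee a matching HashPartitioning, so EnsureRequirements
    // (below) will not add another Exchange on top of this join.
    assert(joinOutput.guarantees(HashPartitioning(Seq('bKey), 10)))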

http://git-wip-us.apache.org/repos/asf/spark/blob/114ff926/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index c046dbf..827f7ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite {
     }
   }
 
-  test("HashPartitioning is the output partitioning") {
+  test("HashPartitioning (with nullSafe = true) is the output partitioning") {
     // Cases which do not need an exchange between two data properties.
     checkSatisfied(
       HashPartitioning(Seq('a, 'b, 'c), 10),

http://git-wip-us.apache.org/repos/asf/spark/blob/114ff926/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 6bd57f0..05b009d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -209,7 +209,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
           child: SparkPlan): SparkPlan = {
 
         def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
-          if (child.outputPartitioning != partitioning) {
+          if (!child.outputPartitioning.guarantees(partitioning)) {
             Exchange(partitioning, child)
           } else {
             child
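
The switch from the old equality test to `guarantees` is what makes the
collection useful here: a PartitioningCollection is never structurally equal
to the single target partitioning that EnsureRequirements builds, so the old
check would have forced a redundant shuffle. Roughly (same illustrative DSL
imports as the sketch above):

    val child = PartitioningCollection(Seq(
      HashPartitioning(Seq('x), 5),
      HashPartitioning(Seq('y), 5)))
    val target = HashPartitioning(Seq('x), 5)

    child == target           // false: the old check would insert an Exchange
    child.guarantees(target)  // true:  the new check keeps the existing layout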

http://git-wip-us.apache.org/repos/asf/spark/blob/114ff926/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index 77e7fe7..309716a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution}
 import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.util.ThreadUtils
@@ -57,6 +57,8 @@ case class BroadcastHashOuterJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
 
+  override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
+
   @transient
   private val broadcastFuture = future {
     // Note that we use .execute().collect() because we don't want to convert data to Scala types

http://git-wip-us.apache.org/repos/asf/spark/blob/114ff926/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 7e671e7..a323aea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -22,7 +22,6 @@ import java.util.{HashMap => JavaHashMap}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.util.collection.CompactBuffer
@@ -38,14 +37,6 @@ trait HashOuterJoin {
   val left: SparkPlan
   val right: SparkPlan
 
-  override def outputPartitioning: Partitioning = joinType match {
-    case LeftOuter => left.outputPartitioning
-    case RightOuter => right.outputPartitioning
-    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
-    case x =>
-      throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
-  }
-
   override def output: Seq[Attribute] = {
     joinType match {
       case LeftOuter =>

http://git-wip-us.apache.org/repos/asf/spark/blob/114ff926/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 26a6641..68ccd34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 
 /**
@@ -37,7 +37,9 @@ case class LeftSemiJoinHash(
     right: SparkPlan,
     condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
 
-  override def requiredChildDistribution: Seq[ClusteredDistribution] =
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   protected override def doExecute(): RDD[InternalRow] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/114ff926/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5439e10..fc6efe8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 
 /**
@@ -38,9 +38,10 @@ case class ShuffledHashJoin(
     right: SparkPlan)
   extends BinaryNode with HashJoin {
 
-  override def outputPartitioning: Partitioning = left.outputPartitioning
+  override def outputPartitioning: Partitioning =
+    PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
 
-  override def requiredChildDistribution: Seq[ClusteredDistribution] =
+  override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   protected override def doExecute(): RDD[InternalRow] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/114ff926/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index d29b593..eee8ad8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution}
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 
@@ -44,6 +44,14 @@ case class ShuffledHashOuterJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  override def outputPartitioning: Partitioning = joinType match {
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case x =>
+      throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val joinedRow = new JoinedRow()
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/114ff926/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index bb18b54..41be78a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -40,7 +40,8 @@ case class SortMergeJoin(
 
   override def output: Seq[Attribute] = left.output ++ right.output
 
-  override def outputPartitioning: Partitioning = left.outputPartitioning
+  override def outputPartitioning: Partitioning =
+    PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
 
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/114ff926/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 845ce66..18b0e54 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -23,14 +23,18 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
 import org.apache.spark.sql.test.TestSQLContext._
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 import org.apache.spark.sql.test.TestSQLContext.planner._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Row, SQLConf, execution}
+import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution}
 
 
-class PlannerSuite extends SparkFunSuite {
+class PlannerSuite extends SparkFunSuite with SQLTestUtils {
+
+  override def sqlContext: SQLContext = TestSQLContext
+
   private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
     val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
     val planned =
@@ -157,4 +161,45 @@ class PlannerSuite extends SparkFunSuite {
     val planned = planner.TakeOrderedAndProject(query)
     assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
   }
+
+  test("PartitioningCollection") {
+    withTempTable("normal", "small", "tiny") {
+      testData.registerTempTable("normal")
+      testData.limit(10).registerTempTable("small")
+      testData.limit(3).registerTempTable("tiny")
+
+      // Disable broadcast join
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+        {
+          val numExchanges = sql(
+            """
+              |SELECT *
+              |FROM
+              |  normal JOIN small ON (normal.key = small.key)
+              |  JOIN tiny ON (small.key = tiny.key)
+            """.stripMargin
+          ).queryExecution.executedPlan.collect {
+            case exchange: Exchange => exchange
+          }.length
+          assert(numExchanges === 3)
+        }
+
+        {
+          // This second query joins on different keys:
+          val numExchanges = sql(
+            """
+              |SELECT *
+              |FROM
+              |  normal JOIN small ON (normal.key = small.key)
+              |  JOIN tiny ON (normal.key = tiny.key)
+            """.stripMargin
+          ).queryExecution.executedPlan.collect {
+            case exchange: Exchange => exchange
+          }.length
+          assert(numExchanges === 3)
+        }
+
+      }
+    }
+  }
 }

