Repository: spark
Updated Branches:
  refs/heads/master c048929c6 -> 8211aab07


[SPARK-9858][SQL] Add an ExchangeCoordinator to estimate the number of 
post-shuffle partitions for aggregates and joins (follow-up)

https://issues.apache.org/jira/browse/SPARK-9858

This PR is a follow-up to https://github.com/apache/spark/pull/9276. It 
addresses JoshRosen's comments.

Author: Yin Huai <yh...@databricks.com>

Closes #9453 from yhuai/numReducer-followUp.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8211aab0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8211aab0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8211aab0

Branch: refs/heads/master
Commit: 8211aab0793cf64202b99be4f31bb8a9ae77050d
Parents: c048929
Author: Yin Huai <yh...@databricks.com>
Authored: Fri Nov 6 11:13:51 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Fri Nov 6 11:13:51 2015 -0800

----------------------------------------------------------------------
 .../catalyst/plans/physical/partitioning.scala  |   8 -
 .../apache/spark/sql/execution/Exchange.scala   |  40 +++--
 .../sql/execution/ExchangeCoordinator.scala     |  31 ++--
 .../org/apache/spark/sql/CachedTableSuite.scala | 150 ++++++++++++++-----
 4 files changed, 167 insertions(+), 62 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8211aab0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 9312c81..86b9417 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -165,11 +165,6 @@ sealed trait Partitioning {
    * produced by `A` could have also been produced by `B`.
    */
   def guarantees(other: Partitioning): Boolean = this == other
-
-  def withNumPartitions(newNumPartitions: Int): Partitioning = {
-    throw new IllegalStateException(
-      s"It is not allowed to call withNumPartitions method of a 
${this.getClass.getSimpleName}")
-  }
 }
 
 object Partitioning {
@@ -254,9 +249,6 @@ case class HashPartitioning(expressions: Seq[Expression], 
numPartitions: Int)
     case _ => false
   }
 
-  override def withNumPartitions(newNumPartitions: Int): HashPartitioning = {
-    HashPartitioning(expressions, newNumPartitions)
-  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/8211aab0/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 0f72ec6..a4ce328 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -242,7 +242,7 @@ case class Exchange(
     // update the number of post-shuffle partitions.
     specifiedPartitionStartIndices.foreach { indices =>
       assert(newPartitioning.isInstanceOf[HashPartitioning])
-      newPartitioning = newPartitioning.withNumPartitions(indices.length)
+      newPartitioning = UnknownPartitioning(indices.length)
     }
     new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
   }
@@ -262,7 +262,7 @@ case class Exchange(
 
 object Exchange {
   def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
-    Exchange(newPartitioning, child, None: Option[ExchangeCoordinator])
+    Exchange(newPartitioning, child, coordinator = None: 
Option[ExchangeCoordinator])
   }
 }
 
@@ -315,7 +315,7 @@ private[sql] case class EnsureRequirements(sqlContext: 
SQLContext) extends Rule[
             child.outputPartitioning match {
               case hash: HashPartitioning => true
               case collection: PartitioningCollection =>
-                
collection.partitionings.exists(_.isInstanceOf[HashPartitioning])
+                
collection.partitionings.forall(_.isInstanceOf[HashPartitioning])
               case _ => false
             }
         }
@@ -416,28 +416,48 @@ private[sql] case class EnsureRequirements(sqlContext: 
SQLContext) extends Rule[
       // First check if the existing partitions of the children all match. 
This means they are
       // partitioned by the same partitioning into the same number of 
partitions. In that case,
       // don't try to make them match `defaultPartitions`, just use the 
existing partitioning.
-      // TODO: this should be a cost based decision. For example, a big 
relation should probably
-      // maintain its existing number of partitions and smaller partitions 
should be shuffled.
-      // defaultPartitions is arbitrary.
-      val numPartitions = children.head.outputPartitioning.numPartitions
+      val maxChildrenNumPartitions = 
children.map(_.outputPartitioning.numPartitions).max
       val useExistingPartitioning = 
children.zip(requiredChildDistributions).forall {
         case (child, distribution) => {
           child.outputPartitioning.guarantees(
-            createPartitioning(distribution, numPartitions))
+            createPartitioning(distribution, maxChildrenNumPartitions))
         }
       }
 
       children = if (useExistingPartitioning) {
+        // We do not need to shuffle any child's output.
         children
       } else {
+        // We need to shuffle at least one child's output.
+        // Now, we will determine the number of partitions that will be used 
by created
+        // partitioning schemes.
+        val numPartitions = {
+          // Check whether every child's output would need to be shuffled if we used
+          // maxChildrenNumPartitions.
+          val shufflesAllChildren = 
children.zip(requiredChildDistributions).forall {
+            case (child, distribution) => {
+              !child.outputPartitioning.guarantees(
+                createPartitioning(distribution, maxChildrenNumPartitions))
+            }
+          }
+          // If we need to shuffle all children, we use 
defaultNumPreShufflePartitions as the
+          // number of partitions. Otherwise, we use maxChildrenNumPartitions.
+          if (shufflesAllChildren) defaultNumPreShufflePartitions else 
maxChildrenNumPartitions
+        }
+
         children.zip(requiredChildDistributions).map {
           case (child, distribution) => {
             val targetPartitioning =
-              createPartitioning(distribution, defaultNumPreShufflePartitions)
+              createPartitioning(distribution, numPartitions)
             if (child.outputPartitioning.guarantees(targetPartitioning)) {
               child
             } else {
-              Exchange(targetPartitioning, child)
+              child match {
+                // If child is an exchange, we replace it with
+                // a new one having targetPartitioning.
+                case Exchange(_, c, _) => Exchange(targetPartitioning, c)
+                case _ => Exchange(targetPartitioning, child)
+              }
             }
           }
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/8211aab0/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
index 8dbd69e..827fdd2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import java.util.{Map => JMap, HashMap => JHashMap}
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -97,6 +98,7 @@ private[sql] class ExchangeCoordinator(
    * Registers an [[Exchange]] operator to this coordinator. This method is 
only allowed to be
    * called in the `doPrepare` method of an [[Exchange]] operator.
    */
+  @GuardedBy("this")
   def registerExchange(exchange: Exchange): Unit = synchronized {
     exchanges += exchange
   }
@@ -109,7 +111,7 @@ private[sql] class ExchangeCoordinator(
    */
   private[sql] def estimatePartitionStartIndices(
       mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = {
-    // If we have mapOutputStatistics.length <= numExchange, it is because we 
do not submit
+    // If we have mapOutputStatistics.length < numExchange, it is because we 
do not submit
     // a stage when the number of partitions of this dependency is 0.
     assert(mapOutputStatistics.length <= numExchanges)
 
@@ -121,6 +123,8 @@ private[sql] class ExchangeCoordinator(
         val totalPostShuffleInputSize = 
mapOutputStatistics.map(_.bytesByPartitionId.sum).sum
         // The max at here is to make sure that when we have an empty table, we
         // only have a single post-shuffle partition.
+        // There is no particular reason that we pick 16. We just need a 
number to
+        // prevent maxPostShuffleInputSize from being set to 0.
         val maxPostShuffleInputSize =
           math.max(math.ceil(totalPostShuffleInputSize / 
numPartitions.toDouble).toLong, 16)
         math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize)
@@ -135,6 +139,12 @@ private[sql] class ExchangeCoordinator(
     // Make sure we do get the same number of pre-shuffle partitions for those 
stages.
     val distinctNumPreShufflePartitions =
       mapOutputStatistics.map(stats => 
stats.bytesByPartitionId.length).distinct
+    // The reason that we are expecting a single value of the number of 
pre-shuffle partitions
+    // is that when we add Exchanges, we set the number of pre-shuffle 
partitions
+    // (i.e. map output partitions) using a static setting, which is the value 
of
+    // spark.sql.shuffle.partitions. Even if two input RDDs have 
different
+    // numbers of partitions, they will have the same number of pre-shuffle 
partitions
+    // (i.e. map output partitions).
     assert(
       distinctNumPreShufflePartitions.length == 1,
       "There should be only one distinct value of the number pre-shuffle 
partitions " +
@@ -177,6 +187,7 @@ private[sql] class ExchangeCoordinator(
     partitionStartIndices.toArray
   }
 
+  @GuardedBy("this")
   private def doEstimationIfNecessary(): Unit = synchronized {
     // It is unlikely that this method will be called from multiple threads
     // (when multiple threads trigger the execution of THIS physical)
@@ -209,11 +220,11 @@ private[sql] class ExchangeCoordinator(
 
       // Wait for the finishes of those submitted map stages.
       val mapOutputStatistics = new 
Array[MapOutputStatistics](submittedStageFutures.length)
-      i = 0
-      while (i < submittedStageFutures.length) {
+      var j = 0
+      while (j < submittedStageFutures.length) {
         // This call is a blocking call. If the stage has not finished, we 
will wait at here.
-        mapOutputStatistics(i) = submittedStageFutures(i).get()
-        i += 1
+        mapOutputStatistics(j) = submittedStageFutures(j).get()
+        j += 1
       }
 
       // Now, we estimate partitionStartIndices. partitionStartIndices.length 
will be the
@@ -225,14 +236,14 @@ private[sql] class ExchangeCoordinator(
           Some(estimatePartitionStartIndices(mapOutputStatistics))
         }
 
-      i = 0
-      while (i < numExchanges) {
-        val exchange = exchanges(i)
+      var k = 0
+      while (k < numExchanges) {
+        val exchange = exchanges(k)
         val rdd =
-          exchange.preparePostShuffleRDD(shuffleDependencies(i), 
partitionStartIndices)
+          exchange.preparePostShuffleRDD(shuffleDependencies(k), 
partitionStartIndices)
         newPostShuffleRDDs.put(exchange, rdd)
 
-        i += 1
+        k += 1
       }
 
       // Finally, we set postShuffleRDDs and estimated.

http://git-wip-us.apache.org/repos/asf/spark/blob/8211aab0/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index dbcb011..bce94da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -29,12 +29,12 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark.Accumulators
 import org.apache.spark.sql.columnar._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext}
 import org.apache.spark.storage.{StorageLevel, RDDBlockId}
 
 private case class BigData(s: String)
 
-class CachedTableSuite extends QueryTest with SharedSQLContext {
+class CachedTableSuite extends QueryTest with SQLTestUtils with 
SharedSQLContext {
   import testImplicits._
 
   def rddIdOf(tableName: String): Int = {
@@ -375,53 +375,135 @@ class CachedTableSuite extends QueryTest with 
SharedSQLContext {
       sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"),
       sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY 
key").collect())
     sqlContext.uncacheTable("orderedTable")
+    sqlContext.dropTempTable("orderedTable")
 
     // Set up two tables distributed in the same way. Try this with the data 
distributed into
     // different number of partitions.
     for (numPartitions <- 1 until 10 by 4) {
-      testData.repartition(numPartitions, $"key").registerTempTable("t1")
-      testData2.repartition(numPartitions, $"a").registerTempTable("t2")
+      withTempTable("t1", "t2") {
+        testData.repartition(numPartitions, $"key").registerTempTable("t1")
+        testData2.repartition(numPartitions, $"a").registerTempTable("t2")
+        sqlContext.cacheTable("t1")
+        sqlContext.cacheTable("t2")
+
+        // Joining them should result in no exchanges.
+        verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = 
t2.a"), 0)
+        checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"),
+          sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a"))
+
+        // Grouping on the partition key should result in no exchanges
+        verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0)
+        checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
+          sql("SELECT count(*) FROM testData GROUP BY key"))
+
+        sqlContext.uncacheTable("t1")
+        sqlContext.uncacheTable("t2")
+      }
+    }
+
+    // Distribute the tables into non-matching number of partitions. Need to 
shuffle one side.
+    withTempTable("t1", "t2") {
+      testData.repartition(6, $"key").registerTempTable("t1")
+      testData2.repartition(3, $"a").registerTempTable("t2")
       sqlContext.cacheTable("t1")
       sqlContext.cacheTable("t2")
 
-      // Joining them should result in no exchanges.
-      verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = 
t2.a"), 0)
-      checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"),
-        sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a"))
+      val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key 
= t2.a")
+      verifyNumExchanges(query, 1)
+      
assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6)
+      checkAnswer(
+        query,
+        testData.join(testData2, $"key" === $"a").select($"key", $"value", 
$"a", $"b"))
+      sqlContext.uncacheTable("t1")
+      sqlContext.uncacheTable("t2")
+    }
 
-      // Grouping on the partition key should result in no exchanges
-      verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0)
-      checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
-        sql("SELECT count(*) FROM testData GROUP BY key"))
+    // One side of join is not partitioned in the desired way. Need to shuffle 
one side.
+    withTempTable("t1", "t2") {
+      testData.repartition(6, $"value").registerTempTable("t1")
+      testData2.repartition(6, $"a").registerTempTable("t2")
+      sqlContext.cacheTable("t1")
+      sqlContext.cacheTable("t2")
 
+      val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key 
= t2.a")
+      verifyNumExchanges(query, 1)
+      
assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6)
+      checkAnswer(
+        query,
+        testData.join(testData2, $"key" === $"a").select($"key", $"value", 
$"a", $"b"))
       sqlContext.uncacheTable("t1")
       sqlContext.uncacheTable("t2")
-      sqlContext.dropTempTable("t1")
-      sqlContext.dropTempTable("t2")
     }
 
-    // Distribute the tables into non-matching number of partitions. Need to 
shuffle.
-    testData.repartition(6, $"key").registerTempTable("t1")
-    testData2.repartition(3, $"a").registerTempTable("t2")
-    sqlContext.cacheTable("t1")
-    sqlContext.cacheTable("t2")
+    withTempTable("t1", "t2") {
+      testData.repartition(6, $"value").registerTempTable("t1")
+      testData2.repartition(12, $"a").registerTempTable("t2")
+      sqlContext.cacheTable("t1")
+      sqlContext.cacheTable("t2")
 
-    verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 
2)
-    sqlContext.uncacheTable("t1")
-    sqlContext.uncacheTable("t2")
-    sqlContext.dropTempTable("t1")
-    sqlContext.dropTempTable("t2")
+      val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key 
= t2.a")
+      verifyNumExchanges(query, 1)
+      
assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 
12)
+      checkAnswer(
+        query,
+        testData.join(testData2, $"key" === $"a").select($"key", $"value", 
$"a", $"b"))
+      sqlContext.uncacheTable("t1")
+      sqlContext.uncacheTable("t2")
+    }
 
-    // One side of join is not partitioned in the desired way. Need to shuffle.
-    testData.repartition(6, $"value").registerTempTable("t1")
-    testData2.repartition(6, $"a").registerTempTable("t2")
-    sqlContext.cacheTable("t1")
-    sqlContext.cacheTable("t2")
+    // One side of the join is not partitioned in the desired way. Since the 
number of partitions of
+    // the side that is already partitioned is smaller than that of the side that is 
not partitioned,
+    // we shuffle both sides.
+    withTempTable("t1", "t2") {
+      testData.repartition(6, $"value").registerTempTable("t1")
+      testData2.repartition(3, $"a").registerTempTable("t2")
+      sqlContext.cacheTable("t1")
+      sqlContext.cacheTable("t2")
 
-    verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 
2)
-    sqlContext.uncacheTable("t1")
-    sqlContext.uncacheTable("t2")
-    sqlContext.dropTempTable("t1")
-    sqlContext.dropTempTable("t2")
+      val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key 
= t2.a")
+      verifyNumExchanges(query, 2)
+      checkAnswer(
+        query,
+        testData.join(testData2, $"key" === $"a").select($"key", $"value", 
$"a", $"b"))
+      sqlContext.uncacheTable("t1")
+      sqlContext.uncacheTable("t2")
+    }
+
+    // repartition's column ordering is different from group by column 
ordering.
+    // But they use the same set of columns.
+    withTempTable("t1") {
+      testData.repartition(6, $"value", $"key").registerTempTable("t1")
+      sqlContext.cacheTable("t1")
+
+      val query = sql("SELECT value, key from t1 group by key, value")
+      verifyNumExchanges(query, 0)
+      checkAnswer(
+        query,
+        testData.distinct().select($"value", $"key"))
+      sqlContext.uncacheTable("t1")
+    }
+
+    // repartition's column ordering is different from join condition's column 
ordering.
+    // We will still shuffle because hashcodes of a row depend on the column 
ordering.
+    // If we do not shuffle, we may actually partition the two tables in two 
totally different ways.
+    // See PartitioningSuite for more details.
+    withTempTable("t1", "t2") {
+      val df1 = testData
+      df1.repartition(6, $"value", $"key").registerTempTable("t1")
+      val df2 = testData2.select($"a", $"b".cast("string"))
+      df2.repartition(6, $"a", $"b").registerTempTable("t2")
+      sqlContext.cacheTable("t1")
+      sqlContext.cacheTable("t2")
+
+      val query =
+        sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a 
and t1.value = t2.b")
+      verifyNumExchanges(query, 1)
+      
assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6)
+      checkAnswer(
+        query,
+        df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", 
$"value", $"a", $"b"))
+      sqlContext.uncacheTable("t1")
+      sqlContext.uncacheTable("t2")
+    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to