Repository: spark Updated Branches: refs/heads/master c048929c6 -> 8211aab07
[SPARK-9858][SQL] Add an ExchangeCoordinator to estimate the number of post-shuffle partitions for aggregates and joins (follow-up) https://issues.apache.org/jira/browse/SPARK-9858 This PR is the follow-up work of https://github.com/apache/spark/pull/9276. It addresses JoshRosen's comments. Author: Yin Huai <yh...@databricks.com> Closes #9453 from yhuai/numReducer-followUp. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8211aab0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8211aab0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8211aab0 Branch: refs/heads/master Commit: 8211aab0793cf64202b99be4f31bb8a9ae77050d Parents: c048929 Author: Yin Huai <yh...@databricks.com> Authored: Fri Nov 6 11:13:51 2015 -0800 Committer: Yin Huai <yh...@databricks.com> Committed: Fri Nov 6 11:13:51 2015 -0800 ---------------------------------------------------------------------- .../catalyst/plans/physical/partitioning.scala | 8 - .../apache/spark/sql/execution/Exchange.scala | 40 +++-- .../sql/execution/ExchangeCoordinator.scala | 31 ++-- .../org/apache/spark/sql/CachedTableSuite.scala | 150 ++++++++++++++----- 4 files changed, 167 insertions(+), 62 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/8211aab0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 9312c81..86b9417 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -165,11 +165,6 @@ sealed 
trait Partitioning { * produced by `A` could have also been produced by `B`. */ def guarantees(other: Partitioning): Boolean = this == other - - def withNumPartitions(newNumPartitions: Int): Partitioning = { - throw new IllegalStateException( - s"It is not allowed to call withNumPartitions method of a ${this.getClass.getSimpleName}") - } } object Partitioning { @@ -254,9 +249,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def withNumPartitions(newNumPartitions: Int): HashPartitioning = { - HashPartitioning(expressions, newNumPartitions) - } } /** http://git-wip-us.apache.org/repos/asf/spark/blob/8211aab0/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 0f72ec6..a4ce328 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -242,7 +242,7 @@ case class Exchange( // update the number of post-shuffle partitions. 
specifiedPartitionStartIndices.foreach { indices => assert(newPartitioning.isInstanceOf[HashPartitioning]) - newPartitioning = newPartitioning.withNumPartitions(indices.length) + newPartitioning = UnknownPartitioning(indices.length) } new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) } @@ -262,7 +262,7 @@ case class Exchange( object Exchange { def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { - Exchange(newPartitioning, child, None: Option[ExchangeCoordinator]) + Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) } } @@ -315,7 +315,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child.outputPartitioning match { case hash: HashPartitioning => true case collection: PartitioningCollection => - collection.partitionings.exists(_.isInstanceOf[HashPartitioning]) + collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) case _ => false } } @@ -416,28 +416,48 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // First check if the existing partitions of the children all match. This means they are // partitioned by the same partitioning into the same number of partitions. In that case, // don't try to make them match `defaultPartitions`, just use the existing partitioning. - // TODO: this should be a cost based decision. For example, a big relation should probably - // maintain its existing number of partitions and smaller partitions should be shuffled. - // defaultPartitions is arbitrary. 
- val numPartitions = children.head.outputPartitioning.numPartitions + val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max val useExistingPartitioning = children.zip(requiredChildDistributions).forall { case (child, distribution) => { child.outputPartitioning.guarantees( - createPartitioning(distribution, numPartitions)) + createPartitioning(distribution, maxChildrenNumPartitions)) } } children = if (useExistingPartitioning) { + // We do not need to shuffle any child's output. children } else { + // We need to shuffle at least one child's output. + // Now, we will determine the number of partitions that will be used by the created + // partitioning schemes. + val numPartitions = { + // Let's see if we need to shuffle all children's outputs when we use + // maxChildrenNumPartitions. + val shufflesAllChildren = children.zip(requiredChildDistributions).forall { + case (child, distribution) => { + !child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) + } + } + // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the + // number of partitions. Otherwise, we use maxChildrenNumPartitions. + if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions + } + children.zip(requiredChildDistributions).map { case (child, distribution) => { val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) + createPartitioning(distribution, numPartitions) if (child.outputPartitioning.guarantees(targetPartitioning)) { child } else { - Exchange(targetPartitioning, child) + child match { + // If child is an exchange, we replace it with + // a new one having targetPartitioning.
+ case Exchange(_, c, _) => Exchange(targetPartitioning, c) + case _ => Exchange(targetPartitioning, child) + } } } } http://git-wip-us.apache.org/repos/asf/spark/blob/8211aab0/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala index 8dbd69e..827fdd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.util.{Map => JMap, HashMap => JHashMap} +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer @@ -97,6 +98,7 @@ private[sql] class ExchangeCoordinator( * Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be * called in the `doPrepare` method of an [[Exchange]] operator. */ + @GuardedBy("this") def registerExchange(exchange: Exchange): Unit = synchronized { exchanges += exchange } @@ -109,7 +111,7 @@ private[sql] class ExchangeCoordinator( */ private[sql] def estimatePartitionStartIndices( mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { - // If we have mapOutputStatistics.length <= numExchange, it is because we do not submit + // If we have mapOutputStatistics.length < numExchanges, it is because we do not submit // a stage when the number of partitions of this dependency is 0. assert(mapOutputStatistics.length <= numExchanges) @@ -121,6 +123,8 @@ private[sql] class ExchangeCoordinator( val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum // The max at here is to make sure that when we have an empty table, we // only have a single post-shuffle partition.
+ // There is no particular reason that we pick 16. We just need a number to + // prevent maxPostShuffleInputSize from being set to 0. val maxPostShuffleInputSize = math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16) math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) @@ -135,6 +139,12 @@ private[sql] class ExchangeCoordinator( // Make sure we do get the same number of pre-shuffle partitions for those stages. val distinctNumPreShufflePartitions = mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct + // The reason that we are expecting a single value of the number of pre-shuffle partitions + // is that when we add Exchanges, we set the number of pre-shuffle partitions + // (i.e. map output partitions) using a static setting, which is the value of + // spark.sql.shuffle.partitions. Even if two input RDDs have different + // numbers of partitions, they will have the same number of pre-shuffle partitions + // (i.e. map output partitions). assert( distinctNumPreShufflePartitions.length == 1, "There should be only one distinct value of the number pre-shuffle partitions " + @@ -177,6 +187,7 @@ private[sql] class ExchangeCoordinator( partitionStartIndices.toArray } + @GuardedBy("this") private def doEstimationIfNecessary(): Unit = synchronized { // It is unlikely that this method will be called from multiple threads // (when multiple threads trigger the execution of THIS physical) @@ -209,11 +220,11 @@ private[sql] class ExchangeCoordinator( // Wait for the finishes of those submitted map stages. val mapOutputStatistics = new Array[MapOutputStatistics](submittedStageFutures.length) - i = 0 - while (i < submittedStageFutures.length) { + var j = 0 + while (j < submittedStageFutures.length) { // This call is a blocking call. If the stage has not finished, we will wait at here.
- mapOutputStatistics(i) = submittedStageFutures(i).get() - i += 1 + mapOutputStatistics(j) = submittedStageFutures(j).get() + j += 1 } // Now, we estimate partitionStartIndices. partitionStartIndices.length will be the @@ -225,14 +236,14 @@ private[sql] class ExchangeCoordinator( Some(estimatePartitionStartIndices(mapOutputStatistics)) } - i = 0 - while (i < numExchanges) { - val exchange = exchanges(i) + var k = 0 + while (k < numExchanges) { + val exchange = exchanges(k) val rdd = - exchange.preparePostShuffleRDD(shuffleDependencies(i), partitionStartIndices) + exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices) newPostShuffleRDDs.put(exchange, rdd) - i += 1 + k += 1 } // Finally, we set postShuffleRDDs and estimated. http://git-wip-us.apache.org/repos/asf/spark/blob/8211aab0/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index dbcb011..bce94da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -29,12 +29,12 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.columnar._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} import org.apache.spark.storage.{StorageLevel, RDDBlockId} private case class BigData(s: String) -class CachedTableSuite extends QueryTest with SharedSQLContext { +class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext { import testImplicits._ def rddIdOf(tableName: String): Int = { @@ -375,53 +375,135 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { sql("SELECT key, 
count(*) FROM orderedTable GROUP BY key ORDER BY key"), sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect()) sqlContext.uncacheTable("orderedTable") + sqlContext.dropTempTable("orderedTable") // Set up two tables distributed in the same way. Try this with the data distributed into // different number of partitions. for (numPartitions <- 1 until 10 by 4) { - testData.repartition(numPartitions, $"key").registerTempTable("t1") - testData2.repartition(numPartitions, $"a").registerTempTable("t2") + withTempTable("t1", "t2") { + testData.repartition(numPartitions, $"key").registerTempTable("t1") + testData2.repartition(numPartitions, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + // Joining them should result in no exchanges. + verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) + checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), + sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a")) + + // Grouping on the partition key should result in no exchanges + verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0) + checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), + sql("SELECT count(*) FROM testData GROUP BY key")) + + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + } + + // Distribute the tables into non-matching number of partitions. Need to shuffle one side. + withTempTable("t1", "t2") { + testData.repartition(6, $"key").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") sqlContext.cacheTable("t1") sqlContext.cacheTable("t2") - // Joining them should result in no exchanges. 
- verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) - checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), - sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a")) + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } - // Grouping on the partition key should result in no exchanges - verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0) - checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), - sql("SELECT count(*) FROM testData GROUP BY key")) + // One side of join is not partitioned in the desired way. Need to shuffle one side. + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(6, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) sqlContext.uncacheTable("t1") sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") } - // Distribute the tables into non-matching number of partitions. Need to shuffle. 
- testData.repartition(6, $"key").registerTempTable("t1") - testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(12, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 12) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } - // One side of join is not partitioned in the desired way. Need to shuffle. - testData.repartition(6, $"value").registerTempTable("t1") - testData2.repartition(6, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + // One side of join is not partitioned in the desired way. Since the number of partitions of + // the side that is already partitioned is smaller than the side that is not partitioned, + // we shuffle both sides.
+ withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 2) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + + // repartition's column ordering is different from group by column ordering. + // But they use the same set of columns. + withTempTable("t1") { + testData.repartition(6, $"value", $"key").registerTempTable("t1") + sqlContext.cacheTable("t1") + + val query = sql("SELECT value, key from t1 group by key, value") + verifyNumExchanges(query, 0) + checkAnswer( + query, + testData.distinct().select($"value", $"key")) + sqlContext.uncacheTable("t1") + } + + // repartition's column ordering is different from join condition's column ordering. + // We will still shuffle because hashcodes of a row depend on the column ordering. + // If we do not shuffle, we may actually partition the two tables in two totally different ways. + // See PartitioningSuite for more details.
+ withTempTable("t1", "t2") { + val df1 = testData + df1.repartition(6, $"value", $"key").registerTempTable("t1") + val df2 = testData2.select($"a", $"b".cast("string")) + df2.repartition(6, $"a", $"b").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + val query = + sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org