This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 961dcdfb9845 [SPARK-45882][SQL] BroadcastHashJoinExec propagate partitioning should respect CoalescedHashPartitioning 961dcdfb9845 is described below commit 961dcdfb98455f341c3f6279fa65aa1dd58ca199 Author: ulysses-you <ulyssesyo...@gmail.com> AuthorDate: Tue Nov 14 05:42:13 2023 +0800 [SPARK-45882][SQL] BroadcastHashJoinExec propagate partitioning should respect CoalescedHashPartitioning ### What changes were proposed in this pull request? Add HashPartitioningLike trait and make HashPartitioning and CoalescedHashPartitioning extend it. When we propagate output partitioning, we should handle HashPartitioningLike instead of HashPartitioning. This PR also changes the BroadcastHashJoinExec to use HashPartitioningLike to avoid regression. ### Why are the changes needed? Avoid unnecessary shuffle exchange. ### Does this PR introduce _any_ user-facing change? yes, avoid regression ### How was this patch tested? add test ### Was this patch authored or co-authored using generative AI tooling? no Closes #43753 from ulysses-you/partitioning. 
Authored-by: ulysses-you <ulyssesyo...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/plans/physical/partitioning.scala | 46 ++++++++++------------ .../execution/joins/BroadcastHashJoinExec.scala | 11 +++--- .../scala/org/apache/spark/sql/JoinSuite.scala | 28 ++++++++++++- 3 files changed, 54 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0ae2857161c8..60e6e42bedf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -258,18 +258,8 @@ case object SinglePartition extends Partitioning { SinglePartitionShuffleSpec } -/** - * Represents a partitioning where rows are split up across partitions based on the hash - * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be - * in the same partition. - * - * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires - * stateful operators to retain the same physical partitioning during the lifetime of the query - * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged - * across Spark versions. Violation of this requirement may bring silent correctness issue. 
- */ -case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression with Partitioning with Unevaluable { +trait HashPartitioningLike extends Expression with Partitioning with Unevaluable { + def expressions: Seq[Expression] override def children: Seq[Expression] = expressions override def nullable: Boolean = false @@ -294,6 +284,20 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) } } } +} + +/** + * Represents a partitioning where rows are split up across partitions based on the hash + * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be + * in the same partition. + * + * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires + * stateful operators to retain the same physical partitioning during the lifetime of the query + * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged + * across Spark versions. Violation of this requirement may bring silent correctness issue. + */ +case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) + extends HashPartitioningLike { override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = HashShuffleSpec(this, distribution) @@ -306,7 +310,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) - } case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int) @@ -316,25 +319,18 @@ case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int) * fewer number of partitions. 
*/ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[CoalescedBoundary]) - extends Expression with Partitioning with Unevaluable { - - override def children: Seq[Expression] = from.expressions - override def nullable: Boolean = from.nullable - override def dataType: DataType = from.dataType + extends HashPartitioningLike { - override def satisfies0(required: Distribution): Boolean = from.satisfies0(required) + override def expressions: Seq[Expression] = from.expressions override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions) - override protected def withNewChildrenInternal( - newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning = - copy(from = from.copy(expressions = newChildren)) - override val numPartitions: Int = partitions.length - override def toString: String = from.toString - override def sql: String = from.sql + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning = + copy(from = from.copy(expressions = newChildren)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 68022757ff24..368534d05b1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution} 
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, UnspecifiedDistribution} import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -73,7 +73,7 @@ case class BroadcastHashJoinExec( joinType match { case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => streamedPlan.outputPartitioning match { - case h: HashPartitioning => expandOutputPartitioning(h) + case h: HashPartitioningLike => expandOutputPartitioning(h) case c: PartitioningCollection => expandOutputPartitioning(c) case other => other } @@ -99,7 +99,7 @@ case class BroadcastHashJoinExec( private def expandOutputPartitioning( partitioning: PartitioningCollection): PartitioningCollection = { PartitioningCollection(partitioning.partitionings.flatMap { - case h: HashPartitioning => expandOutputPartitioning(h).partitionings + case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings case c: PartitioningCollection => Seq(expandOutputPartitioning(c)) case other => Seq(other) }) @@ -111,11 +111,12 @@ case class BroadcastHashJoinExec( // the expanded partitioning will have the following expressions: // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). // The expanded expressions are returned as PartitioningCollection. 
- private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = { + private def expandOutputPartitioning( + partitioning: HashPartitioningLike): PartitioningCollection = { PartitioningCollection(partitioning.multiTransformDown { case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) => e +: streamedKeyToBuildKeyMapping(e.canonicalized) - }.asInstanceOf[LazyList[HashPartitioning]] + }.asInstanceOf[LazyList[HashPartitioningLike]] .take(conf.broadcastHashJoinOutputPartitioningExpandLimit)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index c41b85f75e58..909a05ce26f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION} import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.BatchEvalPythonExec import org.apache.spark.sql.internal.SQLConf @@ -1729,4 +1729,30 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan checkAnswer(joined, expected) } + + test("SPARK-45882: BroadcastHashJoinExec propagate partitioning should respect " + + "CoalescedHashPartitioning") { + val cached = spark.sql( + """ + |select /*+ broadcast(testData) */ key, value, a + |from testData join ( + | select a from testData2 group by a + |)tmp on key = a + 
|""".stripMargin).cache() + try { + val df = cached.groupBy("key").count() + val expected = Seq(Row(1, 1), Row(2, 1), Row(3, 1)) + assert(find(df.queryExecution.executedPlan) { + case _: ShuffleExchangeLike => true + case _ => false + }.size == 1, df.queryExecution) + checkAnswer(df, expected) + assert(find(df.queryExecution.executedPlan) { + case _: ShuffleExchangeLike => true + case _ => false + }.isEmpty, df.queryExecution) + } finally { + cached.unpersist() + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org