This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a8d69132fb3 [SPARK-42136] Refactor BroadcastHashJoinExec output partitioning calculation a8d69132fb3 is described below commit a8d69132fb33cc5a7ff365715840ba99c6514c49 Author: Peter Toth <peter.t...@gmail.com> AuthorDate: Tue Feb 7 21:28:58 2023 +0800 [SPARK-42136] Refactor BroadcastHashJoinExec output partitioning calculation ### What changes were proposed in this pull request? This PR refactors `BroadcastHashJoinExec` output partitioning calculation using the new `TreeNode.multiTransformDown()` helper to simplify code and improve performance. ### Why are the changes needed? Simpler code with `TreeNode.multiTransform()`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing UTs. Closes #38038 from peter-toth/SPARK-refactor-broadcasthashjoinexec-output-partitioning. Authored-by: Peter Toth <peter.t...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../execution/joins/BroadcastHashJoinExec.scala | 27 ++++------------------ .../sql/execution/joins/BroadcastJoinSuite.scala | 6 ++--- 2 files changed, 8 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 69c760b5a00..08eaacca2f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -112,28 +112,11 @@ case class BroadcastHashJoinExec( // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). // The expanded expressions are returned as PartitioningCollection. 
private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = { - val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit - var currentNumCombinations = 0 - - def generateExprCombinations( - current: Seq[Expression], - accumulated: Seq[Expression]): Seq[Seq[Expression]] = { - if (currentNumCombinations >= maxNumCombinations) { - Nil - } else if (current.isEmpty) { - currentNumCombinations += 1 - Seq(accumulated) - } else { - val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) - generateExprCombinations(current.tail, accumulated :+ current.head) ++ - buildKeysOpt.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b))) - .getOrElse(Nil) - } - } - - PartitioningCollection( - generateExprCombinations(partitioning.expressions, Nil) - .map(HashPartitioning(_, partitioning.numPartitions))) + PartitioningCollection(partitioning.multiTransformDown { + case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) => + e +: streamedKeyToBuildKeyMapping(e.canonicalized) + }.asInstanceOf[Stream[HashPartitioning]] + .take(conf.broadcastHashJoinOutputPartitioningExpandLimit)) } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 6333808b420..47714c669d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -554,8 +554,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils right = DummySparkPlan()) var expected = PartitioningCollection(Seq( HashPartitioning(Seq(l1, l2, l3), 1), - HashPartitioning(Seq(l1, l2, r2), 1), HashPartitioning(Seq(l1, r1, l3), 1), + HashPartitioning(Seq(l1, l2, r2), 1), 
HashPartitioning(Seq(l1, r1, r2), 1))) assert(bhj.outputPartitioning === expected) @@ -571,8 +571,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils right = DummySparkPlan()) expected = PartitioningCollection(Seq( HashPartitioning(Seq(l1, l2), 1), - HashPartitioning(Seq(l1, r2), 1), HashPartitioning(Seq(r1, l2), 1), + HashPartitioning(Seq(l1, r2), 1), HashPartitioning(Seq(r1, r2), 1), HashPartitioning(Seq(l3), 1), HashPartitioning(Seq(r3), 1))) @@ -623,8 +623,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils val expected = Seq( HashPartitioning(Seq(l1, l2), 1), - HashPartitioning(Seq(l1, r2), 1), HashPartitioning(Seq(r1, l2), 1), + HashPartitioning(Seq(l1, r2), 1), HashPartitioning(Seq(r1, r2), 1)) Seq(1, 2, 3, 4).foreach { limit => --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org