This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a8d69132fb3 [SPARK-42136] Refactor BroadcastHashJoinExec output 
partitioning calculation
a8d69132fb3 is described below

commit a8d69132fb33cc5a7ff365715840ba99c6514c49
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Tue Feb 7 21:28:58 2023 +0800

    [SPARK-42136] Refactor BroadcastHashJoinExec output partitioning calculation
    
    ### What changes were proposed in this pull request?
    This PR refactors the `BroadcastHashJoinExec` output partitioning calculation using the new `TreeNode.multiTransformDown()` helper to simplify the code and improve performance.
    
    ### Why are the changes needed?
    To simplify the code by using `TreeNode.multiTransformDown()`.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing UTs.
    
    Closes #38038 from 
peter-toth/SPARK-refactor-broadcasthashjoinexec-output-partitioning.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../execution/joins/BroadcastHashJoinExec.scala    | 27 ++++------------------
 .../sql/execution/joins/BroadcastJoinSuite.scala   |  6 ++---
 2 files changed, 8 insertions(+), 25 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 69c760b5a00..08eaacca2f4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -112,28 +112,11 @@ case class BroadcastHashJoinExec(
   // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", 
"y").
   // The expanded expressions are returned as PartitioningCollection.
   private def expandOutputPartitioning(partitioning: HashPartitioning): 
PartitioningCollection = {
-    val maxNumCombinations = 
conf.broadcastHashJoinOutputPartitioningExpandLimit
-    var currentNumCombinations = 0
-
-    def generateExprCombinations(
-        current: Seq[Expression],
-        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
-      if (currentNumCombinations >= maxNumCombinations) {
-        Nil
-      } else if (current.isEmpty) {
-        currentNumCombinations += 1
-        Seq(accumulated)
-      } else {
-        val buildKeysOpt = 
streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
-        generateExprCombinations(current.tail, accumulated :+ current.head) ++
-          buildKeysOpt.map(_.flatMap(b => 
generateExprCombinations(current.tail, accumulated :+ b)))
-            .getOrElse(Nil)
-      }
-    }
-
-    PartitioningCollection(
-      generateExprCombinations(partitioning.expressions, Nil)
-        .map(HashPartitioning(_, partitioning.numPartitions)))
+    PartitioningCollection(partitioning.multiTransformDown {
+      case e: Expression if 
streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
+        e +: streamedKeyToBuildKeyMapping(e.canonicalized)
+    }.asInstanceOf[Stream[HashPartitioning]]
+      .take(conf.broadcastHashJoinOutputPartitioningExpandLimit))
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 6333808b420..47714c669d5 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -554,8 +554,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest 
with SQLTestUtils
       right = DummySparkPlan())
     var expected = PartitioningCollection(Seq(
       HashPartitioning(Seq(l1, l2, l3), 1),
-      HashPartitioning(Seq(l1, l2, r2), 1),
       HashPartitioning(Seq(l1, r1, l3), 1),
+      HashPartitioning(Seq(l1, l2, r2), 1),
       HashPartitioning(Seq(l1, r1, r2), 1)))
     assert(bhj.outputPartitioning === expected)
 
@@ -571,8 +571,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest 
with SQLTestUtils
       right = DummySparkPlan())
     expected = PartitioningCollection(Seq(
       HashPartitioning(Seq(l1, l2), 1),
-      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, l2), 1),
+      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, r2), 1),
       HashPartitioning(Seq(l3), 1),
       HashPartitioning(Seq(r3), 1)))
@@ -623,8 +623,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest 
with SQLTestUtils
 
     val expected = Seq(
       HashPartitioning(Seq(l1, l2), 1),
-      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, l2), 1),
+      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, r2), 1))
 
     Seq(1, 2, 3, 4).foreach { limit =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to