This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 73bb619d45b2 [SPARK-48235][SQL] Directly pass join instead of all 
arguments to getBroadcastBuildSide and getShuffleHashJoinBuildSide
73bb619d45b2 is described below

commit 73bb619d45b2d0699ca4a9d251eea57c359f275b
Author: fred-db <fredrik.kla...@databricks.com>
AuthorDate: Fri May 10 07:45:28 2024 -0700

    [SPARK-48235][SQL] Directly pass join instead of all arguments to 
getBroadcastBuildSide and getShuffleHashJoinBuildSide
    
    ### What changes were proposed in this pull request?
    
    * Refactor getBroadcastBuildSide and getShuffleHashJoinBuildSide to take
      the join as an argument instead of each of its member variables separately.
    
    ### Why are the changes needed?
    
    * Makes the code easier to read.
    
    ### Does this PR introduce _any_ user-facing change?
    
    * no
    
    ### How was this patch tested?
    
    * Existing UTs
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    * No
    
    Closes #46525 from fred-db/parameter-change.
    
    Authored-by: fred-db <fredrik.kla...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/sql/catalyst/optimizer/joins.scala       | 56 +++++++++-----------
 .../optimizer/JoinSelectionHelperSuite.scala       | 59 +++++-----------------
 .../spark/sql/execution/SparkStrategies.scala      |  6 +--
 3 files changed, 40 insertions(+), 81 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 2b4ee033b088..5571178832db 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -289,58 +289,52 @@ case object BuildLeft extends BuildSide
 trait JoinSelectionHelper {
 
   def getBroadcastBuildSide(
-      left: LogicalPlan,
-      right: LogicalPlan,
-      joinType: JoinType,
-      hint: JoinHint,
+      join: Join,
       hintOnly: Boolean,
       conf: SQLConf): Option[BuildSide] = {
     val buildLeft = if (hintOnly) {
-      hintToBroadcastLeft(hint)
+      hintToBroadcastLeft(join.hint)
     } else {
-      canBroadcastBySize(left, conf) && !hintToNotBroadcastLeft(hint)
+      canBroadcastBySize(join.left, conf) && !hintToNotBroadcastLeft(join.hint)
     }
     val buildRight = if (hintOnly) {
-      hintToBroadcastRight(hint)
+      hintToBroadcastRight(join.hint)
     } else {
-      canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint)
+      canBroadcastBySize(join.right, conf) && 
!hintToNotBroadcastRight(join.hint)
     }
     getBuildSide(
-      canBuildBroadcastLeft(joinType) && buildLeft,
-      canBuildBroadcastRight(joinType) && buildRight,
-      left,
-      right
+      canBuildBroadcastLeft(join.joinType) && buildLeft,
+      canBuildBroadcastRight(join.joinType) && buildRight,
+      join.left,
+      join.right
     )
   }
 
   def getShuffleHashJoinBuildSide(
-      left: LogicalPlan,
-      right: LogicalPlan,
-      joinType: JoinType,
-      hint: JoinHint,
+      join: Join,
       hintOnly: Boolean,
       conf: SQLConf): Option[BuildSide] = {
     val buildLeft = if (hintOnly) {
-      hintToShuffleHashJoinLeft(hint)
+      hintToShuffleHashJoinLeft(join.hint)
     } else {
-      hintToPreferShuffleHashJoinLeft(hint) ||
-        (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(left, conf) &&
-          muchSmaller(left, right, conf)) ||
+      hintToPreferShuffleHashJoinLeft(join.hint) ||
+        (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(join.left, 
conf) &&
+          muchSmaller(join.left, join.right, conf)) ||
         forceApplyShuffledHashJoin(conf)
     }
     val buildRight = if (hintOnly) {
-      hintToShuffleHashJoinRight(hint)
+      hintToShuffleHashJoinRight(join.hint)
     } else {
-      hintToPreferShuffleHashJoinRight(hint) ||
-        (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(right, conf) 
&&
-          muchSmaller(right, left, conf)) ||
+      hintToPreferShuffleHashJoinRight(join.hint) ||
+        (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(join.right, 
conf) &&
+          muchSmaller(join.right, join.left, conf)) ||
         forceApplyShuffledHashJoin(conf)
     }
     getBuildSide(
-      canBuildShuffledHashJoinLeft(joinType) && buildLeft,
-      canBuildShuffledHashJoinRight(joinType) && buildRight,
-      left,
-      right
+      canBuildShuffledHashJoinLeft(join.joinType) && buildLeft,
+      canBuildShuffledHashJoinRight(join.joinType) && buildRight,
+      join.left,
+      join.right
     )
   }
 
@@ -401,10 +395,8 @@ trait JoinSelectionHelper {
   }
 
   def canPlanAsBroadcastHashJoin(join: Join, conf: SQLConf): Boolean = {
-    getBroadcastBuildSide(join.left, join.right, join.joinType,
-      join.hint, hintOnly = true, conf).isDefined ||
-      getBroadcastBuildSide(join.left, join.right, join.joinType,
-        join.hint, hintOnly = false, conf).isDefined
+    getBroadcastBuildSide(join, hintOnly = true, conf).isDefined ||
+      getBroadcastBuildSide(join, hintOnly = false, conf).isDefined
   }
 
   def canPruneLeft(joinType: JoinType): Boolean = joinType match {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala
index 6acce44922f6..61fb68cfba86 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.AttributeMap
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
-import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, 
JoinHint, NO_BROADCAST_HASH, SHUFFLE_HASH}
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, 
JoinHint, NO_BROADCAST_HASH, SHUFFLE_HASH}
 import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
 import org.apache.spark.sql.internal.SQLConf
 
@@ -38,16 +38,15 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
     size = Some(1000),
     attributeStats = AttributeMap(Seq()))
 
+  private val join = Join(left, right, Inner, None, JoinHint(None, None))
+
   private val hintBroadcast = Some(HintInfo(Some(BROADCAST)))
   private val hintNotToBroadcast = Some(HintInfo(Some(NO_BROADCAST_HASH)))
   private val hintShuffleHash = Some(HintInfo(Some(SHUFFLE_HASH)))
 
   test("getBroadcastBuildSide (hintOnly = true) return BuildLeft with only a 
left hint") {
     val broadcastSide = getBroadcastBuildSide(
-      left,
-      right,
-      Inner,
-      JoinHint(hintBroadcast, None),
+      join.copy(hint = JoinHint(hintBroadcast, None)),
       hintOnly = true,
       SQLConf.get
     )
@@ -56,10 +55,7 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
 
   test("getBroadcastBuildSide (hintOnly = true) return BuildRight with only a 
right hint") {
     val broadcastSide = getBroadcastBuildSide(
-      left,
-      right,
-      Inner,
-      JoinHint(None, hintBroadcast),
+      join.copy(hint = JoinHint(None, hintBroadcast)),
       hintOnly = true,
       SQLConf.get
     )
@@ -68,10 +64,7 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
 
   test("getBroadcastBuildSide (hintOnly = true) return smaller side with both 
having hints") {
     val broadcastSide = getBroadcastBuildSide(
-      left,
-      right,
-      Inner,
-      JoinHint(hintBroadcast, hintBroadcast),
+      join.copy(hint = JoinHint(hintBroadcast, hintBroadcast)),
       hintOnly = true,
       SQLConf.get
     )
@@ -80,10 +73,7 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
 
   test("getBroadcastBuildSide (hintOnly = true) return None when no side has a 
hint") {
     val broadcastSide = getBroadcastBuildSide(
-      left,
-      right,
-      Inner,
-      JoinHint(None, None),
+      join.copy(hint = JoinHint(None, None)),
       hintOnly = true,
       SQLConf.get
     )
@@ -92,10 +82,7 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
 
   test("getBroadcastBuildSide (hintOnly = false) return BuildRight when right 
is broadcastable") {
     val broadcastSide = getBroadcastBuildSide(
-      left,
-      right,
-      Inner,
-      JoinHint(None, None),
+      join.copy(hint = JoinHint(None, None)),
       hintOnly = false,
       SQLConf.get
     )
@@ -105,10 +92,7 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
   test("getBroadcastBuildSide (hintOnly = false) return None when right has no 
broadcast hint") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {
       val broadcastSide = getBroadcastBuildSide(
-        left,
-        right,
-        Inner,
-        JoinHint(None, hintNotToBroadcast ),
+        join.copy(hint = JoinHint(None, hintNotToBroadcast)),
         hintOnly = false,
         SQLConf.get
       )
@@ -118,10 +102,7 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
 
   test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildLeft with 
only a left hint") {
     val broadcastSide = getShuffleHashJoinBuildSide(
-      left,
-      right,
-      Inner,
-      JoinHint(hintShuffleHash, None),
+      join.copy(hint = JoinHint(hintShuffleHash, None)),
       hintOnly = true,
       SQLConf.get
     )
@@ -130,10 +111,7 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
 
   test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildRight with 
only a right hint") {
     val broadcastSide = getShuffleHashJoinBuildSide(
-      left,
-      right,
-      Inner,
-      JoinHint(None, hintShuffleHash),
+      join.copy(hint = JoinHint(None, hintShuffleHash)),
       hintOnly = true,
       SQLConf.get
     )
@@ -142,10 +120,7 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
 
   test("getShuffleHashJoinBuildSide (hintOnly = true) return smaller side when 
both have hints") {
     val broadcastSide = getShuffleHashJoinBuildSide(
-      left,
-      right,
-      Inner,
-      JoinHint(hintShuffleHash, hintShuffleHash),
+      join.copy(hint = JoinHint(hintShuffleHash, hintShuffleHash)),
       hintOnly = true,
       SQLConf.get
     )
@@ -154,10 +129,7 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
 
   test("getShuffleHashJoinBuildSide (hintOnly = true) return None when no side 
has a hint") {
     val broadcastSide = getShuffleHashJoinBuildSide(
-      left,
-      right,
-      Inner,
-      JoinHint(None, None),
+      join.copy(hint = JoinHint(None, None)),
       hintOnly = true,
       SQLConf.get
     )
@@ -166,10 +138,7 @@ class JoinSelectionHelperSuite extends PlanTest with 
JoinSelectionHelper {
 
   test("getShuffleHashJoinBuildSide (hintOnly = false) return BuildRight when 
right is smaller") {
     val broadcastSide = getBroadcastBuildSide(
-      left,
-      right,
-      Inner,
-      JoinHint(None, None),
+      join.copy(hint = JoinHint(None, None)),
       hintOnly = false,
       SQLConf.get
     )
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 348cc00a1f97..9e14d13b5cb1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -248,8 +248,7 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         val hashJoinSupport = hashJoinSupported(leftKeys, rightKeys)
         def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = {
           if (hashJoinSupport) {
-            val buildSide = getBroadcastBuildSide(
-              left, right, joinType, hint, onlyLookingAtHint, conf)
+            val buildSide = getBroadcastBuildSide(j, onlyLookingAtHint, conf)
             checkHintBuildSide(onlyLookingAtHint, buildSide, joinType, hint, 
true)
             buildSide.map {
               buildSide =>
@@ -269,8 +268,7 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
 
         def createShuffleHashJoin(onlyLookingAtHint: Boolean) = {
           if (hashJoinSupport) {
-            val buildSide = getShuffleHashJoinBuildSide(
-              left, right, joinType, hint, onlyLookingAtHint, conf)
+            val buildSide = getShuffleHashJoinBuildSide(j, onlyLookingAtHint, 
conf)
             checkHintBuildSide(onlyLookingAtHint, buildSide, joinType, hint, 
false)
             buildSide.map {
               buildSide =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to