This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 73bb619d45b2 [SPARK-48235][SQL] Directly pass join instead of all arguments to getBroadcastBuildSide and getShuffleHashJoinBuildSide 73bb619d45b2 is described below commit 73bb619d45b2d0699ca4a9d251eea57c359f275b Author: fred-db <fredrik.kla...@databricks.com> AuthorDate: Fri May 10 07:45:28 2024 -0700 [SPARK-48235][SQL] Directly pass join instead of all arguments to getBroadcastBuildSide and getShuffleHashJoinBuildSide ### What changes were proposed in this pull request? * Refactor getBroadcastBuildSide and getShuffleHashJoinBuildSide to pass the join as an argument instead of all member variables of the join separately. ### Why are the changes needed? * Makes the code easier to read. ### Does this PR introduce _any_ user-facing change? * no ### How was this patch tested? * Existing UTs ### Was this patch authored or co-authored using generative AI tooling? * No Closes #46525 from fred-db/parameter-change. Authored-by: fred-db <fredrik.kla...@databricks.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../spark/sql/catalyst/optimizer/joins.scala | 56 +++++++++----------- .../optimizer/JoinSelectionHelperSuite.scala | 59 +++++----------------- .../spark/sql/execution/SparkStrategies.scala | 6 +-- 3 files changed, 40 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 2b4ee033b088..5571178832db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -289,58 +289,52 @@ case object BuildLeft extends BuildSide trait JoinSelectionHelper { def getBroadcastBuildSide( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - hint: JoinHint, + join: Join, hintOnly: Boolean, conf: SQLConf): Option[BuildSide] = 
{ val buildLeft = if (hintOnly) { - hintToBroadcastLeft(hint) + hintToBroadcastLeft(join.hint) } else { - canBroadcastBySize(left, conf) && !hintToNotBroadcastLeft(hint) + canBroadcastBySize(join.left, conf) && !hintToNotBroadcastLeft(join.hint) } val buildRight = if (hintOnly) { - hintToBroadcastRight(hint) + hintToBroadcastRight(join.hint) } else { - canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint) + canBroadcastBySize(join.right, conf) && !hintToNotBroadcastRight(join.hint) } getBuildSide( - canBuildBroadcastLeft(joinType) && buildLeft, - canBuildBroadcastRight(joinType) && buildRight, - left, - right + canBuildBroadcastLeft(join.joinType) && buildLeft, + canBuildBroadcastRight(join.joinType) && buildRight, + join.left, + join.right ) } def getShuffleHashJoinBuildSide( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - hint: JoinHint, + join: Join, hintOnly: Boolean, conf: SQLConf): Option[BuildSide] = { val buildLeft = if (hintOnly) { - hintToShuffleHashJoinLeft(hint) + hintToShuffleHashJoinLeft(join.hint) } else { - hintToPreferShuffleHashJoinLeft(hint) || - (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(left, conf) && - muchSmaller(left, right, conf)) || + hintToPreferShuffleHashJoinLeft(join.hint) || + (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(join.left, conf) && + muchSmaller(join.left, join.right, conf)) || forceApplyShuffledHashJoin(conf) } val buildRight = if (hintOnly) { - hintToShuffleHashJoinRight(hint) + hintToShuffleHashJoinRight(join.hint) } else { - hintToPreferShuffleHashJoinRight(hint) || - (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(right, conf) && - muchSmaller(right, left, conf)) || + hintToPreferShuffleHashJoinRight(join.hint) || + (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(join.right, conf) && + muchSmaller(join.right, join.left, conf)) || forceApplyShuffledHashJoin(conf) } getBuildSide( - canBuildShuffledHashJoinLeft(joinType) && buildLeft, - 
canBuildShuffledHashJoinRight(joinType) && buildRight, - left, - right + canBuildShuffledHashJoinLeft(join.joinType) && buildLeft, + canBuildShuffledHashJoinRight(join.joinType) && buildRight, + join.left, + join.right ) } @@ -401,10 +395,8 @@ trait JoinSelectionHelper { } def canPlanAsBroadcastHashJoin(join: Join, conf: SQLConf): Boolean = { - getBroadcastBuildSide(join.left, join.right, join.joinType, - join.hint, hintOnly = true, conf).isDefined || - getBroadcastBuildSide(join.left, join.right, join.joinType, - join.hint, hintOnly = false, conf).isDefined + getBroadcastBuildSide(join, hintOnly = true, conf).isDefined || + getBroadcastBuildSide(join, hintOnly = false, conf).isDefined } def canPruneLeft(joinType: JoinType): Boolean = joinType match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala index 6acce44922f6..61fb68cfba86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.AttributeMap import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, JoinHint, NO_BROADCAST_HASH, SHUFFLE_HASH} +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint, NO_BROADCAST_HASH, SHUFFLE_HASH} import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan import org.apache.spark.sql.internal.SQLConf @@ -38,16 +38,15 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { size = Some(1000), attributeStats = AttributeMap(Seq())) + private val join = 
Join(left, right, Inner, None, JoinHint(None, None)) + private val hintBroadcast = Some(HintInfo(Some(BROADCAST))) private val hintNotToBroadcast = Some(HintInfo(Some(NO_BROADCAST_HASH))) private val hintShuffleHash = Some(HintInfo(Some(SHUFFLE_HASH))) test("getBroadcastBuildSide (hintOnly = true) return BuildLeft with only a left hint") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(hintBroadcast, None), + join.copy(hint = JoinHint(hintBroadcast, None)), hintOnly = true, SQLConf.get ) @@ -56,10 +55,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getBroadcastBuildSide (hintOnly = true) return BuildRight with only a right hint") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(None, hintBroadcast), + join.copy(hint = JoinHint(None, hintBroadcast)), hintOnly = true, SQLConf.get ) @@ -68,10 +64,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getBroadcastBuildSide (hintOnly = true) return smaller side with both having hints") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(hintBroadcast, hintBroadcast), + join.copy(hint = JoinHint(hintBroadcast, hintBroadcast)), hintOnly = true, SQLConf.get ) @@ -80,10 +73,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getBroadcastBuildSide (hintOnly = true) return None when no side has a hint") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(None, None), + join.copy(hint = JoinHint(None, None)), hintOnly = true, SQLConf.get ) @@ -92,10 +82,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getBroadcastBuildSide (hintOnly = false) return BuildRight when right is broadcastable") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(None, None), + join.copy(hint = JoinHint(None, None)), hintOnly = false, SQLConf.get ) @@ -105,10 
+92,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getBroadcastBuildSide (hintOnly = false) return None when right has no broadcast hint") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(None, hintNotToBroadcast ), + join.copy(hint = JoinHint(None, hintNotToBroadcast)), hintOnly = false, SQLConf.get ) @@ -118,10 +102,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildLeft with only a left hint") { val broadcastSide = getShuffleHashJoinBuildSide( - left, - right, - Inner, - JoinHint(hintShuffleHash, None), + join.copy(hint = JoinHint(hintShuffleHash, None)), hintOnly = true, SQLConf.get ) @@ -130,10 +111,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildRight with only a right hint") { val broadcastSide = getShuffleHashJoinBuildSide( - left, - right, - Inner, - JoinHint(None, hintShuffleHash), + join.copy(hint = JoinHint(None, hintShuffleHash)), hintOnly = true, SQLConf.get ) @@ -142,10 +120,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getShuffleHashJoinBuildSide (hintOnly = true) return smaller side when both have hints") { val broadcastSide = getShuffleHashJoinBuildSide( - left, - right, - Inner, - JoinHint(hintShuffleHash, hintShuffleHash), + join.copy(hint = JoinHint(hintShuffleHash, hintShuffleHash)), hintOnly = true, SQLConf.get ) @@ -154,10 +129,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getShuffleHashJoinBuildSide (hintOnly = true) return None when no side has a hint") { val broadcastSide = getShuffleHashJoinBuildSide( - left, - right, - Inner, - JoinHint(None, None), + join.copy(hint = JoinHint(None, None)), hintOnly = true, SQLConf.get ) @@ -166,10 
+138,7 @@ class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { test("getShuffleHashJoinBuildSide (hintOnly = false) return BuildRight when right is smaller") { val broadcastSide = getBroadcastBuildSide( - left, - right, - Inner, - JoinHint(None, None), + join.copy(hint = JoinHint(None, None)), hintOnly = false, SQLConf.get ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 348cc00a1f97..9e14d13b5cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -248,8 +248,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val hashJoinSupport = hashJoinSupported(leftKeys, rightKeys) def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = { if (hashJoinSupport) { - val buildSide = getBroadcastBuildSide( - left, right, joinType, hint, onlyLookingAtHint, conf) + val buildSide = getBroadcastBuildSide(j, onlyLookingAtHint, conf) checkHintBuildSide(onlyLookingAtHint, buildSide, joinType, hint, true) buildSide.map { buildSide => @@ -269,8 +268,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def createShuffleHashJoin(onlyLookingAtHint: Boolean) = { if (hashJoinSupport) { - val buildSide = getShuffleHashJoinBuildSide( - left, right, joinType, hint, onlyLookingAtHint, conf) + val buildSide = getShuffleHashJoinBuildSide(j, onlyLookingAtHint, conf) checkHintBuildSide(onlyLookingAtHint, buildSide, joinType, hint, false) buildSide.map { buildSide => --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org