This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 44634ada5 [VL] Remove a limit for BHJ in stage fallback policy (#7105)
44634ada5 is described below

commit 44634ada5b981e7f86f7d3be565f209f21aa679d
Author: PHILO-HE <[email protected]>
AuthorDate: Wed Sep 4 09:04:04 2024 +0800

    [VL] Remove a limit for BHJ in stage fallback policy (#7105)
---
 .../apache/gluten/execution/FallbackSuite.scala    | 41 +++++++++++-----------
 .../extension/columnar/ExpandFallbackPolicy.scala  | 23 +-----------
 2 files changed, 22 insertions(+), 42 deletions(-)

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
index 2b40ac54b..0f1256923 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
@@ -20,10 +20,10 @@ import org.apache.gluten.GlutenConfig
 import org.apache.gluten.extension.GlutenPlan
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan}
+import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, 
ColumnarShuffleExchangeExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, 
AQEShuffleReadExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
SortMergeJoinExec}
 
 class FallbackSuite extends VeloxWholeStageTransformerSuite with 
AdaptiveSparkPlanHelper {
   protected val rootPath: String = getClass.getResource("/").getPath
@@ -106,35 +106,36 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
     }
   }
 
-  test("fallback with bhj") {
-    withSQLConf(GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> 
"2") {
+  test("offload BroadcastExchange and fall back BHJ") {
+    withSQLConf(
+      "spark.gluten.sql.columnar.broadcastJoin" -> "false"
+    ) {
       runQueryAndCompare(
         """
-          |SELECT *, java_method('java.lang.Integer', 'sum', tmp1.c1, tmp2.c1) 
FROM tmp1
-          |LEFT JOIN tmp2 on tmp1.c1 = tmp2.c1
+          |SELECT java_method('java.lang.Integer', 'sum', tmp1.c1, tmp2.c1) 
FROM tmp1
+          |LEFT JOIN tmp2 on tmp1.c1 = tmp2.c1 limit 10
           |""".stripMargin
       ) {
         df =>
           val plan = df.queryExecution.executedPlan
-          val bhj = find(plan) {
+          val columnarBhj = find(plan) {
             case _: BroadcastHashJoinExecTransformerBase => true
             case _ => false
           }
-          assert(bhj.isDefined)
-          val columnarToRow = collectColumnarToRow(bhj.get)
-          assert(columnarToRow == 0)
+          assert(!columnarBhj.isDefined)
 
-          val wholeQueryColumnarToRow = collectColumnarToRow(plan)
-          assert(wholeQueryColumnarToRow == 1)
-      }
+          val vanillaBhj = find(plan) {
+            case _: BroadcastHashJoinExec => true
+            case _ => false
+          }
+          assert(vanillaBhj.isDefined)
 
-      // before the fix, it would fail
-      spark
-        .sql("""
-               |SELECT *, java_method('java.lang.Integer', 'sum', tmp1.c1, 
tmp2.c1) FROM tmp1
-               |LEFT JOIN tmp2 on tmp1.c1 = tmp2.c1
-               |""".stripMargin)
-        .show()
+          val columnarBroadcastExchange = find(plan) {
+            case _: ColumnarBroadcastExchangeExec => true
+            case _ => false
+          }
+          assert(columnarBroadcastExchange.isDefined)
+      }
     }
   }
 
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala
index 491b54443..29e1caae7 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala
@@ -17,7 +17,6 @@
 package org.apache.gluten.extension.columnar
 
 import org.apache.gluten.GlutenConfig
-import org.apache.gluten.execution.BroadcastHashJoinExecTransformerBase
 import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, 
RowToColumnarLike, Transitions}
 import org.apache.gluten.utils.PlanUtil
@@ -27,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
BroadcastQueryStageExec, QueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
QueryStageExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.ExecutedCommandExec
 import org.apache.spark.sql.execution.exchange.Exchange
@@ -179,21 +178,6 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
     stageFallbackTransitionCost
   }
 
-  private def hasColumnarBroadcastExchangeWithJoin(plan: SparkPlan): Boolean = 
{
-    def isColumnarBroadcastExchange(p: SparkPlan): Boolean = p match {
-      case BroadcastQueryStageExec(_, _: ColumnarBroadcastExchangeExec, _) => 
true
-      case _ => false
-    }
-
-    plan.find {
-      case j: BroadcastHashJoinExecTransformerBase
-          if isColumnarBroadcastExchange(j.left) ||
-            isColumnarBroadcastExchange(j.right) =>
-        true
-      case _ => false
-    }.isDefined
-  }
-
   private def fallback(plan: SparkPlan): FallbackInfo = {
     val fallbackThreshold = if (isAdaptiveContext) {
       GlutenConfig.getConf.wholeStageFallbackThreshold
@@ -210,11 +194,6 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
       return FallbackInfo.DO_NOT_FALLBACK()
     }
 
-    // not safe to fallback row-based BHJ as the broadcast exchange is already 
columnar
-    if (hasColumnarBroadcastExchangeWithJoin(plan)) {
-      return FallbackInfo.DO_NOT_FALLBACK()
-    }
-
     val transitionCost = countTransitionCost(plan)
     val fallbackTransitionCost = if (isAdaptiveContext) {
       countStageFallbackTransitionCost(plan)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to