This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 4ca65c69a33d [SPARK-45592][SPARK-45282][SQL] Correctness issue in AQE 
with InMemoryTableScanExec
4ca65c69a33d is described below

commit 4ca65c69a33da33f66969477bc8a6f88154ed305
Author: Maryann Xue <maryann....@gmail.com>
AuthorDate: Tue Nov 14 08:51:26 2023 -0800

    [SPARK-45592][SPARK-45282][SQL] Correctness issue in AQE with 
InMemoryTableScanExec
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a correctness issue that occurs when AQE is enabled for SQL Cache. 
The issue was caused by AQE coalescing the top-level shuffle in the physical plan 
of InMemoryTableScan while still reporting the output partitioning of that 
InMemoryTableScan as HashPartitioning, as if it had not been coalesced. The 
caller query of that InMemoryTableScan in turn failed to align the partitions 
correctly and produced incorrect join results.
    
    The fix addresses the issue by disabling coalescing in InMemoryTableScan 
for shuffles in the final stage. This also guarantees that enabling AQE for 
SQL cache is always a performance win compared with disabling it: AQE 
optimizations are still applied to all non-top-level stages, while no extra 
shuffle is introduced between the parent query and the cached relation 
(if coalescing in top-level shuffles of InMemoryTableScan was not disabled, an 
extra shuffle would end up being added [...]
    
    ### Why are the changes needed?
    
    To fix a correctness issue and to avoid potential AQE performance regressions in 
queries that use SQL cache.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added UTs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43760 from maryannxue/spark-45592.
    
    Authored-by: Maryann Xue <maryann....@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit 128f5523194d5241c7b0f08b5be183288128ba16)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  9 ++++
 .../apache/spark/sql/execution/CacheManager.scala  |  5 ++-
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  8 +++-
 .../org/apache/spark/sql/CachedTableSuite.scala    | 52 +++++++++++++++-------
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 33 +++++++++-----
 5 files changed, 79 insertions(+), 28 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4ea0cd5bcc12..70bd21ac1709 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -657,6 +657,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS =
+    buildConf("spark.sql.adaptive.applyFinalStageShuffleOptimizations")
+      .internal()
+      .doc("Configures whether adaptive query execution (if enabled) should 
apply shuffle " +
+        "coalescing and local shuffle read optimization for the final query 
stage.")
+      .version("3.4.2")
+      .booleanConf
+      .createWithDefault(true)
+
   val ADAPTIVE_EXECUTION_LOG_LEVEL = buildConf("spark.sql.adaptive.logLevel")
     .internal()
     .doc("Configures the log level for adaptive execution logging of plan 
changes. The value " +
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index e906c74f8a5e..9b79865149ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -402,8 +402,9 @@ class CacheManager extends Logging with 
AdaptiveSparkPlanHelper {
     if (session.conf.get(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING)) {
       // Bucketed scan only has one time overhead but can have multi-times 
benefits in cache,
       // so we always do bucketed scan in a cached plan.
-      SparkSession.getOrCloneSessionWithConfigsOff(
-        session, SQLConf.AUTO_BUCKETED_SCAN_ENABLED :: Nil)
+      SparkSession.getOrCloneSessionWithConfigsOff(session,
+        SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS ::
+          SQLConf.AUTO_BUCKETED_SCAN_ENABLED :: Nil)
     } else {
       SparkSession.getOrCloneSessionWithConfigsOff(session, 
forceDisableConfigs)
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 36895b17aa84..fa671c8faf8b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -159,7 +159,13 @@ case class AdaptiveSparkPlanExec(
   )
 
   private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): 
SparkPlan = {
-    val optimized = queryStageOptimizerRules.foldLeft(plan) { case 
(latestPlan, rule) =>
+    val rules = if (isFinalStage &&
+        
!conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS))
 {
+      queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule])
+    } else {
+      queryStageOptimizerRules
+    }
+    val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
       val applied = rule.apply(latestPlan)
       val result = rule match {
         case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) =>
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 1e4a67347f5b..8331a3c10fc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -29,7 +29,7 @@ import org.apache.commons.io.FileUtils
 import org.apache.spark.CleanerListener
 import org.apache.spark.executor.DataReadMethod._
 import org.apache.spark.executor.DataReadMethod.DataReadMethod
-import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, 
SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
@@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.{ColumnarToRowExec, 
ExecSubqueryExpression
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, 
AQEPropagateEmptyRelation}
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import 
org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
@@ -1623,23 +1624,44 @@ class CachedTableSuite extends QueryTest with 
SQLTestUtils
       SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
 
-      withTempView("t1", "t2", "t3") {
-        withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> 
"false") {
-          sql("CACHE TABLE t1 as SELECT /*+ REPARTITION */ * FROM values(1) as 
t(c)")
-          assert(spark.table("t1").rdd.partitions.length == 2)
+      var finalPlan = ""
+      val listener = new SparkListener {
+        override def onOtherEvent(event: SparkListenerEvent): Unit = {
+          event match {
+            case SparkListenerSQLAdaptiveExecutionUpdate(_, physicalPlanDesc, 
sparkPlanInfo) =>
+              if (sparkPlanInfo.simpleString.startsWith(
+                  "AdaptiveSparkPlan isFinalPlan=true")) {
+                finalPlan = physicalPlanDesc
+              }
+            case _ => // ignore other events
+          }
         }
+      }
 
-        withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> 
"true") {
-          assert(spark.table("t1").rdd.partitions.length == 2)
-          sql("CACHE TABLE t2 as SELECT /*+ REPARTITION */ * FROM values(2) as 
t(c)")
-          assert(spark.table("t2").rdd.partitions.length == 1)
-        }
+      withTempView("t0", "t1", "t2") {
+        try {
+          spark.range(10).write.saveAsTable("t0")
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+          spark.sparkContext.addSparkListener(listener)
 
-        withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> 
"false") {
-          assert(spark.table("t1").rdd.partitions.length == 2)
-          assert(spark.table("t2").rdd.partitions.length == 1)
-          sql("CACHE TABLE t3 as SELECT /*+ REPARTITION */ * FROM values(3) as 
t(c)")
-          assert(spark.table("t3").rdd.partitions.length == 2)
+          withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key 
-> "false") {
+            sql("CACHE TABLE t1 as SELECT /*+ REPARTITION */ * FROM (" +
+              "SELECT distinct (id+1) FROM t0)")
+            assert(spark.table("t1").rdd.partitions.length == 2)
+            spark.sparkContext.listenerBus.waitUntilEmpty()
+            assert(finalPlan.nonEmpty && !finalPlan.contains("coalesced"))
+          }
+
+          finalPlan = "" // reset finalPlan
+          withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key 
-> "true") {
+            sql("CACHE TABLE t2 as SELECT /*+ REPARTITION */ * FROM (" +
+              "SELECT distinct (id-1) FROM t0)")
+            assert(spark.table("t2").rdd.partitions.length == 2)
+            spark.sparkContext.listenerBus.waitUntilEmpty()
+            assert(finalPlan.nonEmpty && finalPlan.contains("coalesced"))
+          }
+        } finally {
+          spark.sparkContext.removeSparkListener(listener)
         }
       }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 0878ae134e9d..c2fe31520acf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2550,16 +2550,29 @@ class DatasetSuite extends QueryTest
   }
 
   test("SPARK-45592: Coaleasced shuffle read is not compatible with hash 
partitioning") {
-    val ee = spark.range(0, 1000000, 1, 5).map(l => (l, l)).toDF()
-      .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
-    ee.count()
-
-    val minNbrs1 = ee
-      .groupBy("_1").agg(min(col("_2")).as("min_number"))
-      .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
-
-    val join = ee.join(minNbrs1, "_1")
-    assert(join.count() == 1000000)
+    withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> 
"true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "20",
+      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "2000") {
+      val ee = spark.range(0, 1000, 1, 5).map(l => (l, l - 1)).toDF()
+        .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
+      ee.count()
+
+      // `minNbrs1` will start with 20 partitions and without the fix would 
coalesce to ~10
+      // partitions.
+      val minNbrs1 = ee
+        .groupBy("_2").agg(min(col("_1")).as("min_number"))
+        .select(col("_2") as "_1", col("min_number"))
+        .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
+      minNbrs1.count()
+
+      // shuffle on `ee` will start with 2 partitions, smaller than 
`minNbrs1`'s partition num,
+      // and `EnsureRequirements` will change its partition num to 
`minNbrs1`'s partition num.
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+        val join = ee.join(minNbrs1, "_1")
+        assert(join.count() == 999)
+      }
+    }
   }
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to