This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 58b71307795b [SPARK-48037][CORE][3.5] Fix SortShuffleWriter lacks shuffle write related metrics resulting in potentially inaccurate data
58b71307795b is described below

commit 58b71307795b6060be97431e0c5c8ab95205ea79
Author: sychen <syc...@ctrip.com>
AuthorDate: Tue May 7 22:39:02 2024 -0700

    [SPARK-48037][CORE][3.5] Fix SortShuffleWriter lacks shuffle write related metrics resulting in potentially inaccurate data
    
    ### What changes were proposed in this pull request?
    This PR aims to fix the problem that SortShuffleWriter lacks shuffle-write-related metrics, which can result in inaccurate data.
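    
    For reference, a condensed excerpt of the new wiring (taken from the diff below): SortShuffleManager.getWriter now hands the ShuffleWriteMetricsReporter it receives to SortShuffleWriter, instead of the writer reading context.taskMetrics().shuffleWriteMetrics itself.
    
    ```scala
    // Excerpted from SortShuffleManager.getWriter after this change: the `metrics`
    // reporter passed in (SQLShuffleWriteMetricsReporter for SQL shuffles) is
    // forwarded to SortShuffleWriter.
    case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
      new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents)
    ```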
    
    ### Why are the changes needed?
    When the shuffle writer is SortShuffleWriter, it does not use SQLShuffleWriteMetricsReporter to update metrics, so the runtime statistics that AQE obtains report a rowCount of 0.
    
    Some optimization rules, such as `EliminateLimits`, rely on the rowCount statistic. Because rowCount is 0, the rule removes the limit operator, and the query returns results without the limit applied.
    
    
https://github.com/apache/spark/blob/59d5946cfd377e9203ccf572deb34f87fab7510c/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala#L168-L172
    
    
https://github.com/apache/spark/blob/59d5946cfd377e9203ccf572deb34f87fab7510c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L2067-L2070
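    
    A minimal repro sketch, condensed from the test added below. It assumes a spark-shell session (`spark` in scope) and that the serialized-mode shuffle partition limit is 2^24, so configuring more partitions than that forces the plain SortShuffleWriter path.
    
    ```scala
    // Force the SortShuffleWriter path by exceeding the serialized-mode partition limit.
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.shuffle.partitions", ((1 << 24) + 1).toString)
    spark.sql("CREATE TABLE t3 USING PARQUET AS SELECT id FROM range(2)")

    // Before this fix, the shuffle reported rowCount = 0, EliminateLimits could drop
    // the LIMIT, and this query could return 2 rows instead of at most 1.
    spark.sql("SELECT id, count(*) FROM t3 GROUP BY id LIMIT 1").show()
    ```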
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    ### How was this patch tested?
    Verified in a production environment.
    
    **master metrics**
    <img width="296" alt="image" 
src="https://github.com/apache/spark/assets/3898450/dc9b6e8a-93ec-4f59-a903-71aa5b11962c";>
    
    **PR metrics**
    
    <img width="276" alt="image" 
src="https://github.com/apache/spark/assets/3898450/2d73b773-2dcc-4d23-81de-25dcadac86c1";>
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46459 from cxzl25/SPARK-48037-3.5.
    
    Authored-by: sychen <syc...@ctrip.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/shuffle/sort/SortShuffleManager.scala    |  2 +-
 .../spark/shuffle/sort/SortShuffleWriter.scala     |  6 ++--
 .../spark/util/collection/ExternalSorter.scala     |  9 +++---
 .../shuffle/sort/SortShuffleWriterSuite.scala      |  3 ++
 .../sql/execution/UnsafeRowSerializerSuite.scala   |  3 +-
 .../adaptive/AdaptiveQueryExecSuite.scala          | 32 ++++++++++++++++++++--
 6 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 46aca07ce43f..79dff6f87534 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -176,7 +176,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           metrics,
           shuffleExecutorComponents)
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
+        new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents)
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 8613fe11a4c2..3be7d24f7e4e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -21,6 +21,7 @@ import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.util.collection.ExternalSorter
 
@@ -28,6 +29,7 @@ private[spark] class SortShuffleWriter[K, V, C](
     handle: BaseShuffleHandle[K, V, C],
     mapId: Long,
     context: TaskContext,
+    writeMetrics: ShuffleWriteMetricsReporter,
     shuffleExecutorComponents: ShuffleExecutorComponents)
   extends ShuffleWriter[K, V] with Logging {
 
@@ -46,8 +48,6 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private var partitionLengths: Array[Long] = _
 
-  private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
-
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[Product2[K, V]]): Unit = {
     sorter = if (dep.mapSideCombine) {
@@ -67,7 +67,7 @@ private[spark] class SortShuffleWriter[K, V, C](
     // (see SPARK-3570).
     val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
       dep.shuffleId, mapId, dep.partitioner.numPartitions)
-    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
+    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter, writeMetrics)
     partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
   }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 7153bb72476a..2f2734a389ff 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
-import org.apache.spark.shuffle.ShufflePartitionPairsWriter
+import org.apache.spark.shuffle.{ShufflePartitionPairsWriter, ShuffleWriteMetricsReporter}
 import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
 import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport
 import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
@@ -696,7 +696,8 @@ private[spark] class ExternalSorter[K, V, C](
   def writePartitionedMapOutput(
       shuffleId: Int,
       mapId: Long,
-      mapOutputWriter: ShuffleMapOutputWriter): Unit = {
+      mapOutputWriter: ShuffleMapOutputWriter,
+      writeMetrics: ShuffleWriteMetricsReporter): Unit = {
     if (spills.isEmpty) {
       // Case where we only have in-memory data
       val collection = if (aggregator.isDefined) map else buffer
@@ -713,7 +714,7 @@ private[spark] class ExternalSorter[K, V, C](
             serializerManager,
             serInstance,
             blockId,
-            context.taskMetrics().shuffleWriteMetrics,
+            writeMetrics,
             if (partitionChecksums.nonEmpty) partitionChecksums(partitionId) else null)
           while (it.hasNext && it.nextPartition() == partitionId) {
             it.writeNext(partitionPairsWriter)
@@ -737,7 +738,7 @@ private[spark] class ExternalSorter[K, V, C](
             serializerManager,
             serInstance,
             blockId,
-            context.taskMetrics().shuffleWriteMetrics,
+            writeMetrics,
             if (partitionChecksums.nonEmpty) partitionChecksums(id) else null)
           if (elements.hasNext) {
             for (elem <- elements) {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
index 9e52b5e15143..99402abb16ca 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -85,6 +85,7 @@ class SortShuffleWriterSuite
       shuffleHandle,
       mapId = 1,
       context,
+      context.taskMetrics().shuffleWriteMetrics,
       shuffleExecutorComponents)
     writer.write(Iterator.empty)
     writer.stop(success = true)
@@ -102,6 +103,7 @@ class SortShuffleWriterSuite
       shuffleHandle,
       mapId = 2,
       context,
+      context.taskMetrics().shuffleWriteMetrics,
       shuffleExecutorComponents)
     writer.write(records.iterator)
     writer.stop(success = true)
@@ -158,6 +160,7 @@ class SortShuffleWriterSuite
         shuffleHandle,
         mapId = 0,
         context,
+        context.taskMetrics().shuffleWriteMetrics,
         new LocalDiskShuffleExecutorComponents(
           conf, shuffleBlockResolver._blockManager, shuffleBlockResolver))
       writer.write(records.iterator)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index d94934210615..928d732f2a16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -130,7 +130,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession {
     assert(sorter.numSpills > 0)
 
     // Merging spilled files should not throw assertion error
-    sorter.writePartitionedMapOutput(0, 0, mapOutputWriter)
+    sorter.writePartitionedMapOutput(0, 0, mapOutputWriter,
+      taskContext.taskMetrics.shuffleWriteMetrics)
   }
 
   test("SPARK-10403: unsafe row serializer with SortShuffleManager") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 7c280f72ca17..cab3e69b0d17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
+import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
@@ -60,7 +61,8 @@ class AdaptiveQueryExecSuite
 
   setupTestData()
 
-  private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
+  private def runAdaptiveAndVerifyResult(query: String,
+      skipCheckAnswer: Boolean = false): (SparkPlan, SparkPlan) = {
     var finalPlanCnt = 0
     var hasMetricsEvent = false
     val listener = new SparkListener {
@@ -84,8 +86,10 @@ class AdaptiveQueryExecSuite
     assert(planBefore.toString.startsWith("AdaptiveSparkPlan isFinalPlan=false"))
     val result = dfAdaptive.collect()
     withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
-      val df = sql(query)
-      checkAnswer(df, result)
+      if (!skipCheckAnswer) {
+        val df = sql(query)
+        checkAnswer(df, result)
+      }
     }
     val planAfter = dfAdaptive.queryExecution.executedPlan
     assert(planAfter.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))
@@ -2405,6 +2409,28 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related metrics 
" +
+    "resulting in potentially inaccurate data") {
+    withTable("t3") {
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+        SQLConf.SHUFFLE_PARTITIONS.key -> (SortShuffleManager
+          .MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1).toString) {
+        sql("CREATE TABLE t3 USING PARQUET AS SELECT id FROM range(2)")
+        val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+          """
+            |SELECT id, count(*)
+            |FROM t3
+            |GROUP BY id
+            |LIMIT 1
+            |""".stripMargin, skipCheckAnswer = true)
+        // The shuffle stage produces two rows and the limit operator should not be optimized out.
+        assert(findTopLevelLimit(plan).size == 1)
+        assert(findTopLevelLimit(adaptivePlan).size == 1)
+      }
+    }
+  }
+
   test("SPARK-37063: OptimizeSkewInRebalancePartitions support optimize 
non-root node") {
     withTempView("v") {
       withSQLConf(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
