This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cc9a158  [SPARK-37726][SQL] Add spill size metrics for sort merge join
cc9a158 is described below

commit cc9a158712d7b46382885479dbacb46a253ca800
Author: Cheng Su <chen...@fb.com>
AuthorDate: Thu Dec 30 12:12:30 2021 +0800

    [SPARK-37726][SQL] Add spill size metrics for sort merge join
    
    ### What changes were proposed in this pull request?
    
    Sort merge join allows the buffered side to spill if it is too large to hold
    in memory. It would be good to add a "spill size" SQL metric to sort merge
    join, to track how often spilling happens and how large the spill is when it
    does.
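    
    For illustration, a minimal `spark-shell` sketch that forces the buffered
    side to spill so the new metric shows a non-zero value. The two buffer
    thresholds below are assumed to be the configs behind `getInMemoryThreshold`
    and `getSpillThreshold`; the tiny values only serve to trigger spilling on a
    small dataset:
    
    ```scala
    // Keep the join as a sort merge join and make the buffered side spill early.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
    spark.conf.set("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold", 1)
    spark.conf.set("spark.sql.sortMergeJoinExec.buffer.spill.threshold", 2)
    
    val left = spark.range(0, 1000).selectExpr("id as k")
    val right = spark.range(0, 1000).selectExpr("id % 10 as k", "id as v")
    // Every matching key on the right has 100 rows, so the buffered matches spill.
    left.join(right, "k").count()
    // The SortMergeJoin node on the Spark UI should now report a "spill size".
    ```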
    
    ### Why are the changes needed?
    
    This helps to get more insight into a query when a spill happens. It also
    helps us decide between the Spark code-gen engine and a customized columnar
    engine, in case the customized engine does not support spilling for now.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the new SQL metric "spill size" is shown on the Spark UI.
    
    ### How was this patch tested?
    
    * Modified existing unit test in `SQLMetricsSuite.scala`.
    * Tested simple query with `spark-shell` locally (screenshot below; a
      programmatic sketch of the same check follows):
    <img width="808" alt="Screen Shot 2021-12-23 at 5 21 20 PM" src="https://user-images.githubusercontent.com/4629931/147304391-48890d7c-9701-4112-8eda-da41fc7cacee.png">
    
    Closes #34999 from c21/smj-spill.
    
    Authored-by: Cheng Su <chen...@fb.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../ExternalAppendOnlyUnsafeRowArray.scala         | 13 ++++++
 .../sql/execution/joins/SortMergeJoinExec.scala    | 46 +++++++++++++++++++++-
 .../sql/execution/metric/SQLMetricsSuite.scala     |  7 ++++
 3 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
index 4a064ef..2c9c91e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
@@ -73,6 +73,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
   }
 
   private var spillableArray: UnsafeExternalSorter = _
+  private var totalSpillBytes: Long = 0
   private var numRows = 0
 
  // A counter to keep track of total modifications done to this array since its creation.
@@ -86,10 +87,22 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
   def isEmpty: Boolean = numRows == 0
 
   /**
+   * Total number of bytes that have been spilled to disk so far.
+   */
+  def spillSize: Long = {
+    if (spillableArray != null) {
+      totalSpillBytes + spillableArray.getSpillSize
+    } else {
+      totalSpillBytes
+    }
+  }
+
+  /**
    * Clears up resources (e.g. memory) held by the backing storage
    */
   def clear(): Unit = {
     if (spillableArray != null) {
+      totalSpillBytes += spillableArray.getSpillSize
      // The last `spillableArray` of this task will be cleaned up via task completion listener
       // inside `UnsafeExternalSorter`
       spillableArray.cleanupResources()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 76767cc..ad2e179 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -45,7 +46,8 @@ case class SortMergeJoinExec(
     isSkewJoin: Boolean = false) extends ShuffledJoin {
 
   override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"),
+    "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
 
   override def outputOrdering: Seq[SortOrder] = joinType match {
     // For inner join, orders of both sides keys should be kept.
@@ -123,6 +125,7 @@ case class SortMergeJoinExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val spillSize = longMetric("spillSize")
     val spillThreshold = getSpillThreshold
     val inMemoryThreshold = getInMemoryThreshold
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
@@ -152,6 +155,7 @@ case class SortMergeJoinExec(
               RowIterator.fromScala(rightIter),
               inMemoryThreshold,
               spillThreshold,
+              spillSize,
               cleanupResources
             )
             private[this] val joinRow = new JoinedRow
@@ -197,6 +201,7 @@ case class SortMergeJoinExec(
             bufferedIter = RowIterator.fromScala(rightIter),
             inMemoryThreshold,
             spillThreshold,
+            spillSize,
             cleanupResources
           )
           val rightNullRow = new GenericInternalRow(right.output.length)
@@ -212,6 +217,7 @@ case class SortMergeJoinExec(
             bufferedIter = RowIterator.fromScala(leftIter),
             inMemoryThreshold,
             spillThreshold,
+            spillSize,
             cleanupResources
           )
           val leftNullRow = new GenericInternalRow(left.output.length)
@@ -247,6 +253,7 @@ case class SortMergeJoinExec(
               RowIterator.fromScala(rightIter),
               inMemoryThreshold,
               spillThreshold,
+              spillSize,
               cleanupResources,
               onlyBufferFirstMatchedRow
             )
@@ -284,6 +291,7 @@ case class SortMergeJoinExec(
               RowIterator.fromScala(rightIter),
               inMemoryThreshold,
               spillThreshold,
+              spillSize,
               cleanupResources,
               onlyBufferFirstMatchedRow
             )
@@ -328,6 +336,7 @@ case class SortMergeJoinExec(
               RowIterator.fromScala(rightIter),
               inMemoryThreshold,
               spillThreshold,
+              spillSize,
               cleanupResources,
               onlyBufferFirstMatchedRow
             )
@@ -648,6 +657,13 @@ case class SortMergeJoinExec(
 
   override def needCopyResult: Boolean = true
 
+  /**
+   * This is called by the generated Java class, so it should be public.
+   */
+  def getTaskContext(): TaskContext = {
+    TaskContext.get()
+  }
+
   override def doProduce(ctx: CodegenContext): String = {
    // Specialize `doProduce` code for full outer join, because full outer join needs to
    // buffer both sides of join.
@@ -766,7 +782,7 @@ case class SortMergeJoinExec(
     val thisPlan = ctx.addReferenceObj("plan", this)
     val eagerCleanup = s"$thisPlan.cleanupResources();"
 
-    joinType match {
+    val doJoin = joinType match {
       case _: InnerLike =>
        codegenInner(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, outputRow,
           eagerCleanup)
@@ -786,6 +802,26 @@ case class SortMergeJoinExec(
         throw new IllegalArgumentException(
           s"SortMergeJoin.doProduce should not take $x as the JoinType")
     }
+
+    val initJoin = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initJoin")
+    val addHookToRecordMetrics =
+      s"""
+         |$thisPlan.getTaskContext().addTaskCompletionListener(
+         |  new org.apache.spark.util.TaskCompletionListener() {
+         |    @Override
+         |    public void onTaskCompletion(org.apache.spark.TaskContext context) {
+         |      ${metricTerm(ctx, "spillSize")}.add($matches.spillSize());
+         |    }
+         |});
+       """.stripMargin
+
+    s"""
+       |if (!$initJoin) {
+       |  $initJoin = true;
+       |  $addHookToRecordMetrics
+       |}
+       |$doJoin
+     """.stripMargin
   }
 
   /**
@@ -1231,6 +1267,7 @@ private[joins] class SortMergeJoinScanner(
     bufferedIter: RowIterator,
     inMemoryThreshold: Int,
     spillThreshold: Int,
+    spillSize: SQLMetric,
     eagerCleanupResources: () => Unit,
     onlyBufferFirstMatch: Boolean = false) {
   private[this] var streamedRow: InternalRow = _
@@ -1246,6 +1283,11 @@ private[joins] class SortMergeJoinScanner(
   private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray =
     new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
 
+  // At the end of the task, update the task's spill size for the buffered side.
+  TaskContext.get().addTaskCompletionListener[Unit](_ => {
+    spillSize += bufferedMatches.spillSize
+  })
+
   // Initialization (note: do _not_ want to advance streamed here).
   advancedBufferedToRowWithNullFreeJoinKey()
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 162531ba..0fd5c89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -300,6 +300,13 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
             "shuffle records written" -> 2L)))),
           enableWholeStage
         )
+        testSparkPlanMetricsWithPredicates(df, 1, Map(
+          nodeId1 -> (("SortMergeJoin", Map(
+            "spill size" -> {
+              _.toString.matches(sizeMetricPattern)
+            })))),
+          enableWholeStage
+        )
       }
     }
   }
