This is an automated email from the ASF dual-hosted git repository.

wuyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 336ca8c12a16 [SPARK-49386][CORE][SQL][FOLLOWUP] More accurate memory 
tracking for memory based spill threshold
336ca8c12a16 is described below

commit 336ca8c12a163d31519a491a28d26ff091626986
Author: Wenchen Fan <[email protected]>
AuthorDate: Thu Sep 4 13:24:21 2025 +0800

    [SPARK-49386][CORE][SQL][FOLLOWUP] More accurate memory tracking for memory 
based spill threshold
    
    ### What changes were proposed in this pull request?
    
    This is a followup of https://github.com/apache/spark/pull/47856 . It makes 
the memory tracking more accurate in several places:
    1. In `ShuffleExternalSorter`/`UnsafeExternalSorter`, the memory is used by 
both the sorter itself and its underlying in-memory sorter (for sorting 
shuffle partition ids). We need to add them up to calculate the current memory 
usage.
    2. In `ExternalAppendOnlyUnsafeRowArray`, the records are inserted into an 
in-memory buffer first. If the buffer gets too large (currently based on num 
records), we switch to `UnsafeExternalSorter`. The in-memory buffer also needs 
a memory-based threshold.
    
    ### Why are the changes needed?
    
    More accurate memory tracking results in better spill decisions.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, the feature is not released yet.
    
    ### How was this patch tested?
    
    existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #52190 from cloud-fan/spill.
    
    Lead-authored-by: Wenchen Fan <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Yi Wu <[email protected]>
---
 .../spark/shuffle/sort/ShuffleExternalSorter.java  | 33 +++++++++-------
 .../unsafe/sort/UnsafeExternalSorter.java          | 44 +++++++++++----------
 .../org/apache/spark/internal/config/package.scala |  4 +-
 .../sql/execution/UnsafeKVExternalSorter.java      | 10 ++---
 .../ExternalAppendOnlyUnsafeRowArray.scala         | 46 ++++++++++++++--------
 .../aggregate/UpdatingSessionsIterator.scala       | 11 ++++--
 .../sql/execution/joins/CartesianProductExec.scala | 11 ++++--
 .../joins/SortMergeJoinEvaluatorFactory.scala      | 14 +++----
 .../sql/execution/joins/SortMergeJoinExec.scala    | 26 ++++++++----
 .../python/ArrowAggregatePythonExec.scala          |  4 +-
 .../python/ArrowWindowPythonEvaluatorFactory.scala | 11 ++++--
 .../execution/window/WindowEvaluatorFactory.scala  | 11 ++++--
 ...ExternalAppendOnlyUnsafeRowArrayBenchmark.scala |  5 ++-
 .../ExternalAppendOnlyUnsafeRowArraySuite.scala    |  1 +
 14 files changed, 142 insertions(+), 89 deletions(-)

diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java 
b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index bed241b1e03a..f053135c4dbd 100644
--- 
a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ 
b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -90,9 +90,9 @@ final class ShuffleExternalSorter extends MemoryConsumer 
implements ShuffleCheck
   private final int numElementsForSpillThreshold;
 
   /**
-   * Force this sorter to spill when the size in memory is beyond this 
threshold.
+   * Force this sorter to spill when the in memory size in bytes is beyond 
this threshold.
    */
-  private final long recordsSizeForSpillThreshold;
+  private final long sizeInBytesForSpillThreshold;
 
   /** The buffer size to use when writing spills using DiskBlockObjectWriter */
   private final int fileBufferSizeBytes;
@@ -117,7 +117,7 @@ final class ShuffleExternalSorter extends MemoryConsumer 
implements ShuffleCheck
   @Nullable private ShuffleInMemorySorter inMemSorter;
   @Nullable private MemoryBlock currentPage = null;
   private long pageCursor = -1;
-  private long inMemRecordsSize = 0;
+  private long totalPageMemoryUsageBytes = 0;
 
   // Checksum calculator for each partition. Empty when shuffle checksum 
disabled.
   private final Checksum[] partitionChecksums;
@@ -142,7 +142,7 @@ final class ShuffleExternalSorter extends MemoryConsumer 
implements ShuffleCheck
         (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 
1024;
     this.numElementsForSpillThreshold =
         (int) 
conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD());
-    this.recordsSizeForSpillThreshold =
+    this.sizeInBytesForSpillThreshold =
         (long) 
conf.get(package$.MODULE$.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD());
     this.writeMetrics = writeMetrics;
     this.inMemSorter = new ShuffleInMemorySorter(
@@ -314,11 +314,7 @@ final class ShuffleExternalSorter extends MemoryConsumer 
implements ShuffleCheck
   }
 
   private long getMemoryUsage() {
-    long totalPageSize = 0;
-    for (MemoryBlock page : allocatedPages) {
-      totalPageSize += page.size();
-    }
-    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + 
totalPageSize;
+    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + 
totalPageMemoryUsageBytes;
   }
 
   private void updatePeakMemoryUsed() {
@@ -342,11 +338,11 @@ final class ShuffleExternalSorter extends MemoryConsumer 
implements ShuffleCheck
     for (MemoryBlock block : allocatedPages) {
       memoryFreed += block.size();
       freePage(block);
+      totalPageMemoryUsageBytes -= block.size();
     }
     allocatedPages.clear();
     currentPage = null;
     pageCursor = 0;
-    inMemRecordsSize = 0;
     return memoryFreed;
   }
 
@@ -417,6 +413,7 @@ final class ShuffleExternalSorter extends MemoryConsumer 
implements ShuffleCheck
       currentPage = allocatePage(required);
       pageCursor = currentPage.getBaseOffset();
       allocatedPages.add(currentPage);
+      totalPageMemoryUsageBytes += currentPage.size();
     }
   }
 
@@ -432,10 +429,17 @@ final class ShuffleExternalSorter extends MemoryConsumer 
implements ShuffleCheck
         MDC.of(LogKeys.NUM_ELEMENTS_SPILL_RECORDS, inMemSorter.numRecords()),
         MDC.of(LogKeys.NUM_ELEMENTS_SPILL_THRESHOLD, 
numElementsForSpillThreshold));
       spill();
-    } else if (inMemRecordsSize >= recordsSizeForSpillThreshold) {
-      logger.info("Spilling data because size of spilledRecords ({}) crossed 
the size threshold {}",
-        MDC.of(LogKeys.SPILL_RECORDS_SIZE, inMemRecordsSize),
-        MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, 
recordsSizeForSpillThreshold));
+    }
+
+    // TODO: Ideally we only need to check the spill threshold when new memory 
needs to be
+    //       allocated (both this sorter and the underlying 
ShuffleInMemorySorter may allocate
+    //       new memory), but it's simpler to check the total memory usage of 
these two sorters
+    //       before inserting each record.
+    final long usedMemory = getMemoryUsage();
+    if (usedMemory >= sizeInBytesForSpillThreshold) {
+      logger.info("Spilling data because memory usage ({}) crossed the 
threshold {}",
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE, usedMemory),
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, 
sizeInBytesForSpillThreshold));
       spill();
     }
 
@@ -453,7 +457,6 @@ final class ShuffleExternalSorter extends MemoryConsumer 
implements ShuffleCheck
     Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
     pageCursor += length;
     inMemSorter.insertRecord(recordAddress, partitionId);
-    inMemRecordsSize += required;
   }
 
   /**
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 94c37e187131..71a826642b1b 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -80,9 +80,9 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
   private final int numElementsForSpillThreshold;
 
   /**
-   * Force this sorter to spill when the size in memory is beyond this 
threshold.
+   * Force this sorter to spill when the in memory size in bytes is beyond 
this threshold.
    */
-  private final long recordsSizeForSpillThreshold;
+  private final long sizeInBytesForSpillThreshold;
 
   /**
    * Memory pages that hold the records being sorted. The pages in this list 
are freed when
@@ -96,7 +96,7 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
 
   // These variables are reset after spilling:
   @Nullable private volatile UnsafeInMemorySorter inMemSorter;
-  private long inMemRecordsSize = 0;
+  private long totalPageMemoryUsageBytes = 0;
 
   private MemoryBlock currentPage = null;
   private long pageCursor = -1;
@@ -115,12 +115,12 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
       int initialSize,
       long pageSizeBytes,
       int numElementsForSpillThreshold,
-      long recordsSizeForSpillThreshold,
+      long sizeInBytesForSpillThreshold,
       UnsafeInMemorySorter inMemorySorter,
       long existingMemoryConsumption) throws IOException {
     UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, 
blockManager,
       serializerManager, taskContext, recordComparatorSupplier, 
prefixComparator, initialSize,
-        pageSizeBytes, numElementsForSpillThreshold, 
recordsSizeForSpillThreshold,
+        pageSizeBytes, numElementsForSpillThreshold, 
sizeInBytesForSpillThreshold,
         inMemorySorter, false /* ignored */);
     sorter.spill(Long.MAX_VALUE, sorter);
     taskContext.taskMetrics().incMemoryBytesSpilled(existingMemoryConsumption);
@@ -140,11 +140,11 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
       int initialSize,
       long pageSizeBytes,
       int numElementsForSpillThreshold,
-      long recordsSizeForSpillThreshold,
+      long sizeInBytesForSpillThreshold,
       boolean canUseRadixSort) {
     return new UnsafeExternalSorter(taskMemoryManager, blockManager, 
serializerManager,
       taskContext, recordComparatorSupplier, prefixComparator, initialSize, 
pageSizeBytes,
-      numElementsForSpillThreshold, recordsSizeForSpillThreshold, null, 
canUseRadixSort);
+      numElementsForSpillThreshold, sizeInBytesForSpillThreshold, null, 
canUseRadixSort);
   }
 
   private UnsafeExternalSorter(
@@ -157,7 +157,7 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
       int initialSize,
       long pageSizeBytes,
       int numElementsForSpillThreshold,
-      long recordsSizeForSpillThreshold,
+      long sizeInBytesForSpillThreshold,
       @Nullable UnsafeInMemorySorter existingInMemorySorter,
       boolean canUseRadixSort) {
     super(taskMemoryManager, pageSizeBytes, 
taskMemoryManager.getTungstenMemoryMode());
@@ -187,7 +187,7 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
       this.inMemSorter = existingInMemorySorter;
     }
     this.peakMemoryUsedBytes = getMemoryUsage();
-    this.recordsSizeForSpillThreshold = recordsSizeForSpillThreshold;
+    this.sizeInBytesForSpillThreshold = sizeInBytesForSpillThreshold;
     this.numElementsForSpillThreshold = numElementsForSpillThreshold;
 
     // Register a cleanup task with TaskContext to ensure that memory is 
guaranteed to be freed at
@@ -248,7 +248,6 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
     // pages will currently be counted as memory spilled even though that 
space isn't actually
     // written to disk. This also counts the space needed to store the 
sorter's pointer array.
     inMemSorter.freeMemory();
-    inMemRecordsSize = 0;
     // Reset the in-memory sorter's pointer array only after freeing up the 
memory pages holding the
     // records. Otherwise, if the task is over allocated memory, then without 
freeing the memory
     // pages, we might not be able to get memory for the pointer array.
@@ -264,11 +263,7 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
    * array.
    */
   private long getMemoryUsage() {
-    long totalPageSize = 0;
-    for (MemoryBlock page : allocatedPages) {
-      totalPageSize += page.size();
-    }
-    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + 
totalPageSize;
+    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + 
totalPageMemoryUsageBytes;
   }
 
   private void updatePeakMemoryUsed() {
@@ -320,6 +315,7 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
     for (MemoryBlock block : pagesToFree) {
       memoryFreed += block.size();
       freePage(block);
+      totalPageMemoryUsageBytes -= block.size();
     }
     return memoryFreed;
   }
@@ -378,6 +374,7 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
     } finally {
       for (MemoryBlock pageToFree : pagesToFree) {
         freePage(pageToFree);
+        totalPageMemoryUsageBytes -= pageToFree.size();
       }
       if (inMemSorterToFree != null) {
         inMemSorterToFree.freeMemory();
@@ -448,6 +445,7 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
       currentPage = allocatePage(required);
       pageCursor = currentPage.getBaseOffset();
       allocatedPages.add(currentPage);
+      totalPageMemoryUsageBytes += currentPage.size();
     }
   }
 
@@ -495,10 +493,17 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
         MDC.of(LogKeys.NUM_ELEMENTS_SPILL_RECORDS, inMemSorter.numRecords()),
         MDC.of(LogKeys.NUM_ELEMENTS_SPILL_THRESHOLD, 
numElementsForSpillThreshold));
       spill();
-    } else if (inMemRecordsSize >= recordsSizeForSpillThreshold) {
-      logger.info("Spilling data because size of spilledRecords ({}) crossed 
the size threshold {}",
-        MDC.of(LogKeys.SPILL_RECORDS_SIZE, inMemRecordsSize),
-        MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, 
recordsSizeForSpillThreshold));
+    }
+
+    // TODO: Ideally we only need to check the spill threshold when new memory 
needs to be
+    //       allocated (both this sorter and the underlying 
UnsafeInMemorySorter may allocate
+    //       new memory), but it's simpler to check the total memory usage of 
these two sorters
+    //       before inserting each record.
+    final long usedMemory = getMemoryUsage();
+    if (usedMemory >= sizeInBytesForSpillThreshold) {
+      logger.info("Spilling data because memory usage ({}) crossed the 
threshold {}",
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE, usedMemory),
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, 
sizeInBytesForSpillThreshold));
       spill();
     }
 
@@ -514,7 +519,6 @@ public final class UnsafeExternalSorter extends 
MemoryConsumer {
     Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
     pageCursor += length;
     inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
-    inMemRecordsSize += required;
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala 
b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 0bee708bca3c..120948064f92 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1599,9 +1599,9 @@ package object config {
       .createWithDefault(Integer.MAX_VALUE)
 
   private[spark] val SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD =
-    ConfigBuilder("spark.shuffle.spill.maxRecordsSizeForSpillThreshold")
+    ConfigBuilder("spark.shuffle.spill.maxSizeInBytesForSpillThreshold")
       .internal()
-      .doc("The maximum size in memory before forcing the shuffle sorter to 
spill. " +
+      .doc("The maximum in memory size in bytes before forcing the shuffle 
sorter to spill. " +
         "By default it is Long.MAX_VALUE, which means we never force the 
sorter to spill, " +
         "until we reach some limitations, like the max page size limitation 
for the pointer " +
         "array in the sorter.")
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 6affcb61b8d6..ca49c5f306ca 100644
--- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -61,9 +61,9 @@ public final class UnsafeKVExternalSorter {
       SerializerManager serializerManager,
       long pageSizeBytes,
       int numElementsForSpillThreshold,
-      long maxRecordsSizeForSpillThreshold) throws IOException {
+      long sizeInBytesForSpillThreshold) throws IOException {
     this(keySchema, valueSchema, blockManager, serializerManager, 
pageSizeBytes,
-      numElementsForSpillThreshold, maxRecordsSizeForSpillThreshold, null);
+      numElementsForSpillThreshold, sizeInBytesForSpillThreshold, null);
   }
 
   public UnsafeKVExternalSorter(
@@ -73,7 +73,7 @@ public final class UnsafeKVExternalSorter {
       SerializerManager serializerManager,
       long pageSizeBytes,
       int numElementsForSpillThreshold,
-      long maxRecordsSizeForSpillThreshold,
+      long sizeInBytesForSpillThreshold,
       @Nullable BytesToBytesMap map) throws IOException {
     this.keySchema = keySchema;
     this.valueSchema = valueSchema;
@@ -100,7 +100,7 @@ public final class UnsafeKVExternalSorter {
         (int) (long) 
SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()),
         pageSizeBytes,
         numElementsForSpillThreshold,
-        maxRecordsSizeForSpillThreshold,
+        sizeInBytesForSpillThreshold,
         canUseRadixSort);
     } else {
       // During spilling, the pointer array in `BytesToBytesMap` will not be 
used, so we can borrow
@@ -168,7 +168,7 @@ public final class UnsafeKVExternalSorter {
         (int) (long) 
SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()),
         pageSizeBytes,
         numElementsForSpillThreshold,
-        maxRecordsSizeForSpillThreshold,
+        sizeInBytesForSpillThreshold,
         inMemSorter,
         map.getTotalMemoryConsumption());
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
index 3e98c28b29fb..e8d5858b04fe 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.LogKeys.{CLASS_NAME, 
MAX_NUM_ROWS_IN_MEMORY_BUFFER}
+import org.apache.spark.internal.LogKeys.{CLASS_NAME, 
MAX_NUM_ROWS_IN_MEMORY_BUFFER, NUM_BYTES_MAX}
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -34,17 +34,18 @@ import 
org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, Unsaf
 
 /**
  * An append-only array for [[UnsafeRow]]s that strictly keeps content in an 
in-memory array
- * until [[numRowsInMemoryBufferThreshold]] is reached post which it will 
switch to a mode which
- * would flush to disk after [[numRowsSpillThreshold]] is met (or before if 
there is
- * excessive memory consumption). Setting these threshold involves following 
trade-offs:
+ * until [[numRowsInMemoryBufferThreshold]] or 
[[sizeInBytesInMemoryBufferThreshold]] is reached
+ * post which it will switch to a mode (backed by [[UnsafeExternalSorter]]) 
which would flush to
+ * disk after [[numRowsSpillThreshold]] or [[sizeInBytesSpillThreshold]] is 
met (or before if there
+ * is excessive memory consumption). Setting these threshold involves 
following trade-offs:
  *
- * - If [[numRowsInMemoryBufferThreshold]] is too high, the in-memory array 
may occupy more memory
- *   than is available, resulting in OOM.
- * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently 
and lead to
- *   excessive disk writes. This may lead to a performance regression compared 
to the normal case
- *   of using an [[ArrayBuffer]] or [[Array]].
+ * - If [[numRowsInMemoryBufferThreshold]] and 
[[sizeInBytesInMemoryBufferThreshold]] are too high,
+ *   the in-memory array may occupy more memory than is available, resulting 
in OOM.
+ * - If [[numRowsSpillThreshold]] or [[sizeInBytesSpillThreshold]] is too low, 
data will be spilled
+ *   frequently and lead to excessive disk writes. This may lead to a 
performance regression
+ *   compared to the normal case of using an [[ArrayBuffer]] or [[Array]].
  */
-private[sql] class ExternalAppendOnlyUnsafeRowArray(
+class ExternalAppendOnlyUnsafeRowArray(
     taskMemoryManager: TaskMemoryManager,
     blockManager: BlockManager,
     serializerManager: SerializerManager,
@@ -52,12 +53,15 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
     initialSize: Int,
     pageSizeBytes: Long,
     numRowsInMemoryBufferThreshold: Int,
+    sizeInBytesInMemoryBufferThreshold: Long,
     numRowsSpillThreshold: Int,
-    maxSizeSpillThreshold: Long) extends Logging {
+    sizeInBytesSpillThreshold: Long) extends Logging {
 
-  def this(numRowsInMemoryBufferThreshold: Int,
-    numRowsSpillThreshold: Int,
-    maxSizeSpillThreshold: Long) = {
+  def this(
+      numRowsInMemoryBufferThreshold: Int,
+      sizeInBytesInMemoryBufferThreshold: Long,
+      numRowsSpillThreshold: Int,
+      sizeInBytesSpillThreshold: Long) = {
     this(
       TaskContext.get().taskMemoryManager(),
       SparkEnv.get.blockManager,
@@ -66,8 +70,9 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
       1024,
       SparkEnv.get.memoryManager.pageSizeBytes,
       numRowsInMemoryBufferThreshold,
+      sizeInBytesInMemoryBufferThreshold,
       numRowsSpillThreshold,
-      maxSizeSpillThreshold)
+      sizeInBytesSpillThreshold)
   }
 
   private val initialSizeOfInMemoryBuffer =
@@ -78,6 +83,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
   } else {
     null
   }
+  private var inMemoryBufferSizeInBytes = 0L
 
   private var spillableArray: UnsafeExternalSorter = _
   private var totalSpillBytes: Long = 0
@@ -116,6 +122,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
       spillableArray = null
     } else if (inMemoryBuffer != null) {
       inMemoryBuffer.clear()
+      inMemoryBufferSizeInBytes = 0;
     }
     numFieldsPerRow = 0
     numRows = 0
@@ -123,12 +130,16 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
   }
 
   def add(unsafeRow: UnsafeRow): Unit = {
-    if (numRows < numRowsInMemoryBufferThreshold) {
+    // Once spills, we will switch to UnsafeExternalSorter permanently.
+    if (spillableArray == null && numRows < numRowsInMemoryBufferThreshold &&
+      inMemoryBufferSizeInBytes < sizeInBytesInMemoryBufferThreshold) {
       inMemoryBuffer += unsafeRow.copy()
+      inMemoryBufferSizeInBytes += unsafeRow.getSizeInBytes
     } else {
       if (spillableArray == null) {
         logInfo(log"Reached spill threshold of " +
           log"${MDC(MAX_NUM_ROWS_IN_MEMORY_BUFFER, 
numRowsInMemoryBufferThreshold)} rows, " +
+          log"or ${MDC(NUM_BYTES_MAX, sizeInBytesInMemoryBufferThreshold)} 
bytes, " +
           log"switching to ${MDC(CLASS_NAME, 
classOf[UnsafeExternalSorter].getName)}")
 
         // We will not sort the rows, so prefixComparator and recordComparator 
are null
@@ -142,7 +153,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
           initialSize,
           pageSizeBytes,
           numRowsSpillThreshold,
-          maxSizeSpillThreshold,
+          sizeInBytesSpillThreshold,
           false)
 
         // populate with existing in-memory buffered rows
@@ -156,6 +167,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
               false)
           )
           inMemoryBuffer.clear()
+          inMemoryBufferSizeInBytes = 0
         }
         numFieldsPerRow = unsafeRow.numFields()
       }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala
index 64bb3717f52b..cf146889912d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala
@@ -44,7 +44,7 @@ class UpdatingSessionsIterator(
     inputSchema: Seq[Attribute],
     inMemoryThreshold: Int,
     spillThreshold: Int,
-    spillSizeThreshold: Long) extends Iterator[InternalRow] {
+    sizeInBytesSpillThreshold: Long) extends Iterator[InternalRow] {
 
   private val groupingWithoutSession: Seq[NamedExpression] =
     groupingExpressions.diff(Seq(sessionExpression))
@@ -151,8 +151,13 @@ class UpdatingSessionsIterator(
     currentKeys = groupingKey.copy()
     currentSession = sessionStruct.copy()
 
-    rowsForCurrentSession = new 
ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold,
-      spillSizeThreshold)
+    rowsForCurrentSession = new ExternalAppendOnlyUnsafeRowArray(
+      inMemoryThreshold,
+      // TODO: shall we have a new config to specify the max in-memory buffer 
size
+      //       of ExternalAppendOnlyUnsafeRowArray?
+      sizeInBytesSpillThreshold,
+      spillThreshold,
+      sizeInBytesSpillThreshold)
     rowsForCurrentSession.add(currentRow.asInstanceOf[UnsafeRow])
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 8065decb0dff..a8523746f9d2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -37,12 +37,17 @@ class UnsafeCartesianRDD(
     right : RDD[UnsafeRow],
     inMemoryBufferThreshold: Int,
     spillThreshold: Int,
-    spillSizeThreshold: Long)
+    sizeInBytesSpillThreshold: Long)
   extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
 
   override def compute(split: Partition, context: TaskContext): 
Iterator[(UnsafeRow, UnsafeRow)] = {
-    val rowArray = new 
ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold,
-      spillSizeThreshold)
+    val rowArray = new ExternalAppendOnlyUnsafeRowArray(
+      inMemoryBufferThreshold,
+      // TODO: shall we have a new config to specify the max in-memory buffer 
size
+      //       of ExternalAppendOnlyUnsafeRowArray?
+      sizeInBytesSpillThreshold,
+      spillThreshold,
+      sizeInBytesSpillThreshold)
 
     val partition = split.asInstanceOf[CartesianPartition]
     rdd2.iterator(partition.s2, context).foreach(rowArray.add)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala
index b4e52ba050b8..2b6a19dfa8a8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala
@@ -34,7 +34,7 @@ class SortMergeJoinEvaluatorFactory(
     output: Seq[Attribute],
     inMemoryThreshold: Int,
     spillThreshold: Int,
-    spillSizeThreshold: Long,
+    sizeInBytesSpillThreshold: Long,
     numOutputRows: SQLMetric,
     spillSize: SQLMetric,
     onlyBufferFirstMatchedRow: Boolean)
@@ -86,7 +86,7 @@ class SortMergeJoinEvaluatorFactory(
               RowIterator.fromScala(rightIter),
               inMemoryThreshold,
               spillThreshold,
-              spillSizeThreshold,
+              sizeInBytesSpillThreshold,
               spillSize,
               cleanupResources)
             private[this] val joinRow = new JoinedRow
@@ -132,7 +132,7 @@ class SortMergeJoinEvaluatorFactory(
             bufferedIter = RowIterator.fromScala(rightIter),
             inMemoryThreshold,
             spillThreshold,
-            spillSizeThreshold,
+            sizeInBytesSpillThreshold,
             spillSize,
             cleanupResources)
           val rightNullRow = new GenericInternalRow(right.output.length)
@@ -152,7 +152,7 @@ class SortMergeJoinEvaluatorFactory(
             bufferedIter = RowIterator.fromScala(leftIter),
             inMemoryThreshold,
             spillThreshold,
-            spillSizeThreshold,
+            sizeInBytesSpillThreshold,
             spillSize,
             cleanupResources)
           val leftNullRow = new GenericInternalRow(left.output.length)
@@ -189,7 +189,7 @@ class SortMergeJoinEvaluatorFactory(
               RowIterator.fromScala(rightIter),
               inMemoryThreshold,
               spillThreshold,
-              spillSizeThreshold,
+              sizeInBytesSpillThreshold,
               spillSize,
               cleanupResources,
               onlyBufferFirstMatchedRow)
@@ -227,7 +227,7 @@ class SortMergeJoinEvaluatorFactory(
               RowIterator.fromScala(rightIter),
               inMemoryThreshold,
               spillThreshold,
-              spillSizeThreshold,
+              sizeInBytesSpillThreshold,
               spillSize,
               cleanupResources,
               onlyBufferFirstMatchedRow)
@@ -272,7 +272,7 @@ class SortMergeJoinEvaluatorFactory(
               RowIterator.fromScala(rightIter),
               inMemoryThreshold,
               spillThreshold,
-              spillSizeThreshold,
+              sizeInBytesSpillThreshold,
               spillSize,
               cleanupResources,
               onlyBufferFirstMatchedRow)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 39387ebbb7ee..bc2f9197df9d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -103,7 +103,7 @@ case class SortMergeJoinExec(
     conf.sortMergeJoinExecBufferSpillThreshold
   }
 
-  private def getSpillSizeThreshold: Long = {
+  private def getSizeInBytesSpillThreshold: Long = {
     conf.sortMergeJoinExecBufferSpillSizeThreshold
   }
 
@@ -125,7 +125,7 @@ case class SortMergeJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     val spillSize = longMetric("spillSize")
     val spillThreshold = getSpillThreshold
-    val spillSizeThreshold = getSpillSizeThreshold
+    val sizeInBytesSpillThreshold = getSizeInBytesSpillThreshold
     val inMemoryThreshold = getInMemoryThreshold
     val evaluatorFactory = new SortMergeJoinEvaluatorFactory(
       leftKeys,
@@ -137,7 +137,7 @@ case class SortMergeJoinExec(
       output,
       inMemoryThreshold,
       spillThreshold,
-      spillSizeThreshold,
+      sizeInBytesSpillThreshold,
       numOutputRows,
       spillSize,
       onlyBufferFirstMatchedRow
@@ -228,12 +228,15 @@ case class SortMergeJoinExec(
     val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
 
     val spillThreshold = getSpillThreshold
-    val spillSizeThreshold = getSpillSizeThreshold
+    val sizeInBytesSpillThreshold = getSizeInBytesSpillThreshold
     val inMemoryThreshold = getInMemoryThreshold
 
     // Inline mutable state since not many join operations in a task
     val matches = ctx.addMutableState(clsName, "matches",
-      v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold, 
${spillSizeThreshold}L);",
+      // TODO: shall we have a new config to specify the max in-memory buffer 
size
+      //       of ExternalAppendOnlyUnsafeRowArray?
+      v => s"$v = new $clsName($inMemoryThreshold, 
${sizeInBytesSpillThreshold}L, " +
+        s"$spillThreshold, ${sizeInBytesSpillThreshold}L);",
       forceInline = true)
     // Copy the streamed keys as class members so they could be used in next 
function call.
     val matchedKeyVars = copyKeys(ctx, streamedKeyVars)
@@ -1052,7 +1055,8 @@ case class SortMergeJoinExec(
  * @param inMemoryThreshold Threshold for number of rows guaranteed to be held 
in memory by
  *                          internal buffer
  * @param spillThreshold Threshold for number of rows to be spilled by 
internal buffer
- * @param spillSizeThreshold Threshold for size of rows to be spilled by 
internal buffer
+ * @param sizeInBytesSpillThreshold Threshold for size in bytes of rows to be 
spilled by
+ *                                  internal buffer
  * @param eagerCleanupResources the eager cleanup function to be invoked when 
no join row found
  * @param onlyBufferFirstMatch [[bufferMatchingRows]] should buffer only the 
first matching row
  */
@@ -1064,7 +1068,7 @@ private[joins] class SortMergeJoinScanner(
     bufferedIter: RowIterator,
     inMemoryThreshold: Int,
     spillThreshold: Int,
-    spillSizeThreshold: Long,
+    sizeInBytesSpillThreshold: Long,
     spillSize: SQLMetric,
     eagerCleanupResources: () => Unit,
     onlyBufferFirstMatch: Boolean = false) {
@@ -1079,7 +1083,13 @@ private[joins] class SortMergeJoinScanner(
   private[this] var matchJoinKey: InternalRow = _
   /** Buffered rows from the buffered side of the join. This is empty if there 
are no matches. */
   private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray =
-    new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, 
spillSizeThreshold)
+    new ExternalAppendOnlyUnsafeRowArray(
+      inMemoryThreshold,
+      // TODO: shall we have a new config to specify the max in-memory buffer 
size
+      //       of ExternalAppendOnlyUnsafeRowArray?
+      sizeInBytesSpillThreshold,
+      spillThreshold,
+      sizeInBytesSpillThreshold)
 
   // At the end of the task, update the task's spill size for buffered side.
   TaskContext.get().addTaskCompletionListener[Unit](_ => {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
index d6cc350e485a..a92679054dd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
@@ -215,10 +215,10 @@ case class ArrowAggregatePythonExec(
       case Some(sessionExpression) =>
         val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
         val spillThreshold = conf.windowExecBufferSpillThreshold
-        val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold
+        val sizeInBytesSpillThreshold = conf.windowExecBufferSpillSizeThreshold
 
         new UpdatingSessionsIterator(iter, groupingWithoutSessionExpressions, 
sessionExpression,
-          child.output, inMemoryThreshold, spillThreshold, spillSizeThreshold)
+          child.output, inMemoryThreshold, spillThreshold, 
sizeInBytesSpillThreshold)
 
       case None => iter
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
index 92ed9ff9de45..1643a8d3bdb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
@@ -149,7 +149,7 @@ class ArrowWindowPythonEvaluatorFactory(
 
     private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
     private val spillThreshold = conf.windowExecBufferSpillThreshold
-    private val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold
+    private val sizeInBytesSpillThreshold = 
conf.windowExecBufferSpillSizeThreshold
     private val sessionLocalTimeZone = conf.sessionLocalTimeZone
     private val largeVarTypes = conf.arrowUseLargeVarTypes
 
@@ -288,8 +288,13 @@ class ArrowWindowPythonEvaluatorFactory(
 
         // Manage the current partition.
         val buffer: ExternalAppendOnlyUnsafeRowArray =
-          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, 
spillThreshold,
-            spillSizeThreshold)
+          new ExternalAppendOnlyUnsafeRowArray(
+            inMemoryThreshold,
+            // TODO: shall we have a new config to specify the max in-memory 
buffer size
+            //       of ExternalAppendOnlyUnsafeRowArray?
+            sizeInBytesSpillThreshold,
+            spillThreshold,
+            sizeInBytesSpillThreshold)
         var bufferIterator: Iterator[UnsafeRow] = _
 
         val indexRow =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala
index d59a0e9f4639..c4b20d4b7c7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala
@@ -45,7 +45,7 @@ class WindowEvaluatorFactory(
     private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
     private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
     private val spillThreshold = conf.windowExecBufferSpillThreshold
-    private val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold
+    private val sizeInBytesSpillThreshold = 
conf.windowExecBufferSpillSizeThreshold
 
     override def eval(
         partitionIndex: Int,
@@ -83,8 +83,13 @@ class WindowEvaluatorFactory(
 
         // Manage the current partition.
         val buffer: ExternalAppendOnlyUnsafeRowArray =
-          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, 
spillThreshold,
-            spillSizeThreshold)
+          new ExternalAppendOnlyUnsafeRowArray(
+            inMemoryThreshold,
+            // TODO: shall we have a new config to specify the max in-memory 
buffer size
+            //       of ExternalAppendOnlyUnsafeRowArray?
+            sizeInBytesSpillThreshold,
+            spillThreshold,
+            sizeInBytesSpillThreshold)
 
         var bufferIterator: Iterator[UnsafeRow] = _
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
index 461c899325f4..124e15397ca5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
@@ -107,6 +107,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase {
       for (_ <- 0L until iterations) {
         val array = new ExternalAppendOnlyUnsafeRowArray(
           ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer,
+          Long.MaxValue,
           numSpillThreshold,
           Long.MaxValue)
 
@@ -172,7 +173,9 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase {
     benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
       var sum = 0L
       for (_ <- 0L until iterations) {
-        val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold,
+        val array = new ExternalAppendOnlyUnsafeRowArray(
+          numSpillThreshold,
+          Long.MaxValue,
           numSpillThreshold,
           Long.MaxValue)
         rows.foreach(x => array.add(x))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
index 62ea7f2f9259..e667a95269f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
@@ -47,6 +47,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
         1024,
         SparkEnv.get.memoryManager.pageSizeBytes,
         inMemoryThreshold,
+        Long.MaxValue,
         spillThreshold,
         Long.MaxValue)
       try f(array) finally {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to