This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch dev-v6.0.0-decimal-cast
in repository https://gitbox.apache.org/repos/asf/auron.git

commit 92216fc4441bfae24fe7c0f89761338f916d9ab4
Author: zhangli20 <[email protected]>
AuthorDate: Thu Dec 25 21:23:28 2025 +0800

    fix possible deadlock in OnHeapSpillManager
---
 .../spark/sql/blaze/memory/OnHeapSpill.scala       | 38 ++++++++++++++++++----
 .../sql/blaze/memory/OnHeapSpillManager.scala      |  2 +-
 2 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpill.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpill.scala
index 137f67c4..3e6dac66 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpill.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpill.scala
@@ -16,20 +16,32 @@
 package org.apache.spark.sql.blaze.memory
 
 import java.nio.ByteBuffer
-
 import org.apache.spark.internal.Logging
+import org.apache.spark.memory.MemoryConsumer
+
+import java.util.concurrent.locks.ReentrantLock
 
 case class OnHeapSpill(hsm: OnHeapSpillManager, id: Int) extends Logging {
   private var spillBuf: SpillBuf = new MemBasedSpillBuf
+  private val lock = new ReentrantLock
 
   def memUsed: Long = spillBuf.memUsed
   def diskUsed: Long = spillBuf.diskUsed
   def size: Long = spillBuf.size
   def diskIOTime: Long = spillBuf.diskIOTime
 
+  private def withLock[T](f: => T): T = {
+    lock.lock()
+    try {
+      f
+    } finally {
+      lock.unlock()
+    }
+  }
+
   def write(buf: ByteBuffer): Unit = {
     var needSpill = false
-    synchronized {
+    withLock {
       spillBuf match {
         case _: MemBasedSpillBuf =>
           val acquiredMemory = hsm.acquireMemory(buf.capacity())
@@ -45,13 +57,13 @@ case class OnHeapSpill(hsm: OnHeapSpillManager, id: Int) extends Logging {
       spillInternal()
     }
 
-    synchronized {
+    withLock {
       spillBuf.write(buf)
     }
   }
 
   def read(buf: ByteBuffer): Int = {
-    synchronized {
+    withLock {
       val oldMemUsed = memUsed
       val startPosition = buf.position()
       spillBuf.read(buf)
@@ -68,7 +80,7 @@ case class OnHeapSpill(hsm: OnHeapSpillManager, id: Int) extends Logging {
   }
 
   def release(): Unit = {
-    synchronized {
+    withLock {
       val oldMemUsed = memUsed
       spillBuf = new ReleasedSpillBuf(spillBuf)
 
@@ -78,8 +90,20 @@ case class OnHeapSpill(hsm: OnHeapSpillManager, id: Int) extends Logging {
     }
   }
 
-  def spill(): Long = {
-    synchronized {
+  def spill(trigger: MemoryConsumer): Long = {
+    // this might have been locked if the spilling is triggered by OnHeapSpill.write
+    if (trigger == this.hsm) {
+      if (lock.tryLock()) {
+        try {
+          return spillInternal()
+        } finally {
+          lock.unlock()
+        }
+      }
+      return 0L
+    }
+
+    withLock {
       spillInternal()
     }
   }
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpillManager.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpillManager.scala
index 5f89d710..81d105ab 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpillManager.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpillManager.scala
@@ -158,7 +158,7 @@ class OnHeapSpillManager(taskContext: TaskContext)
        val sortedSpills = spills.seq.sortBy(0 - _.map(_.memUsed).getOrElse(0L))
         sortedSpills.foreach {
           case Some(spill) if spill.memUsed > 0 =>
-            totalFreed += spill.spill()
+            totalFreed += spill.spill(trigger)
             if (totalFreed >= size) {
               return totalFreed
             }

Reply via email to