This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2bc9573e94f2 [SPARK-47819][CONNECT] Use asynchronous callback for execution cleanup
2bc9573e94f2 is described below

commit 2bc9573e94f29cd5394429b623e30c4386a473ba
Author: Xi Lyu <xi....@databricks.com>
AuthorDate: Fri Apr 12 08:48:40 2024 -0400

    [SPARK-47819][CONNECT] Use asynchronous callback for execution cleanup
    
    ### What changes were proposed in this pull request?
    
    Expired sessions are regularly checked and cleaned up by a maintenance thread. However, this cleanup is currently synchronous: closing an execution joins its execution thread. In rare cases, interrupting the execution thread of a query can take hours, which stalls the entire maintenance process and leaves a large amount of memory uncleared.
    
    We address this by introducing an asynchronous callback for execution cleanup, which avoids synchronously joining execution threads and prevents the maintenance thread from stalling in the scenario above. Specifically, instead of calling `runner.join()` in `ExecuteHolder.close()`, we register a post-cleanup function as a callback through `runner.processOnCompletion`; it is called asynchronously once the execution runner completes or is interrupted. In this way, the maintenance thread no longer waits for the execution thread to finish [...]
    
    ### Why are the changes needed?
    
    In the rare cases described above, a single slow interruption can stall the maintenance thread and delay freeing memory, severely affecting server performance.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
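    Existing tests, plus a new test `Async cleanup callback gets called after the execution is closed` in `SparkConnectServiceE2ESuite.scala`. `SparkConnectServiceSuite.scala` was also updated to wait for the asynchronously posted `Closed` status.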
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46027 from xi-db/SPARK-47819-async-cleanup.
    
    Authored-by: Xi Lyu <xi....@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../connect/execution/ExecuteThreadRunner.scala    | 33 ++++++++++++++++------
 .../spark/sql/connect/service/ExecuteHolder.scala  | 16 ++++++++---
 .../connect/planner/SparkConnectServiceSuite.scala |  7 ++++-
 .../service/SparkConnectServiceE2ESuite.scala      | 23 +++++++++++++++
 4 files changed, 65 insertions(+), 14 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 56776819dac9..37c3120a8ff4 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.connect.execution
 
+import scala.concurrent.{ExecutionContext, Promise}
 import scala.jdk.CollectionConverters._
+import scala.util.Try
 import scala.util.control.NonFatal
 
 import com.google.protobuf.Message
@@ -30,7 +32,7 @@ import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, 
SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * This class launches the actual execution in an execution thread. The 
execution pushes the
@@ -38,10 +40,12 @@ import org.apache.spark.util.Utils
  */
 private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) 
extends Logging {
 
+  private val promise: Promise[Unit] = Promise[Unit]()
+
   // The newly created thread will inherit all InheritableThreadLocals used by 
Spark,
   // e.g. SparkContext.localProperties. If considering implementing a 
thread-pool,
   // forwarding of thread locals needs to be taken into account.
-  private val executionThread: Thread = new ExecutionThread()
+  private val executionThread: ExecutionThread = new ExecutionThread(promise)
 
   private var started: Boolean = false
 
@@ -63,11 +67,11 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
     }
   }
 
-  /** Joins the background execution thread after it is finished. */
-  private[connect] def join(): Unit = {
-    // only called when the execution is completed or interrupted.
-    assert(completed || interrupted)
-    executionThread.join()
+  /**
+   * Register a callback that gets executed after completion/interruption of 
the execution thread.
+   */
+  private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit 
= {
+    
promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext)
   }
 
   /**
@@ -276,10 +280,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
       .build()
   }
 
-  private class ExecutionThread
+  private class ExecutionThread(onCompletionPromise: Promise[Unit])
       extends 
Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") {
     override def run(): Unit = {
-      execute()
+      try {
+        execute()
+        onCompletionPromise.success(())
+      } catch {
+        case NonFatal(e) =>
+          onCompletionPromise.failure(e)
+      }
     }
   }
 }
+
+private[connect] object ExecuteThreadRunner {
+  private implicit val namedExecutionContext: ExecutionContext = 
ExecutionContext
+    
.fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback"))
+}
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index f03f81326064..3112d12bb0e6 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -117,6 +117,9 @@ private[connect] class ExecuteHolder(
       : 
mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]] =
     new 
mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]()
 
+  /** For testing. Whether the async completion callback is called. */
+  @volatile private[connect] var completionCallbackCalled: Boolean = false
+
   /**
    * Start the execution. The execution is started in a background thread in 
ExecuteThreadRunner.
    * Responses are produced and cached in ExecuteResponseObserver. A GRPC 
thread consumes the
@@ -238,8 +241,15 @@ private[connect] class ExecuteHolder(
     if (closedTimeMs.isEmpty) {
       // interrupt execution, if still running.
       runner.interrupt()
-      // wait for execution to finish, to make sure no more results get pushed 
to responseObserver
-      runner.join()
+      // Do not wait for the execution to finish, clean up resources 
immediately.
+      runner.processOnCompletion { _ =>
+        completionCallbackCalled = true
+        // The execution may not immediately get interrupted, clean up any 
remaining resources when
+        // it does.
+        responseObserver.removeAll()
+        // post closed to UI
+        eventsManager.postClosed()
+      }
       // interrupt any attached grpcResponseSenders
       grpcResponseSenders.foreach(_.interrupt())
       // if there were still any grpcResponseSenders, register detach time
@@ -249,8 +259,6 @@ private[connect] class ExecuteHolder(
       }
       // remove all cached responses from observer
       responseObserver.removeAll()
-      // post closed to UI
-      eventsManager.postClosed()
       closedTimeMs = Some(System.currentTimeMillis())
     }
   }
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 63cebd452364..af18fca9dd21 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -31,6 +31,8 @@ import org.apache.arrow.vector.{BigIntVector, Float8Vector}
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.mockito.Mockito.when
 import org.scalatest.Tag
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
 import org.scalatestplus.mockito.MockitoSugar
 
 import org.apache.spark.{SparkContext, SparkEnv}
@@ -884,8 +886,11 @@ class SparkConnectServiceSuite
       assert(executeHolder.eventsManager.hasError.isDefined)
     }
     def onCompleted(producedRowCount: Option[Long] = None): Unit = {
-      assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
       assert(executeHolder.eventsManager.getProducedRowCount == 
producedRowCount)
+      // The eventsManager is closed asynchronously
+      Eventually.eventually(timeout(1.seconds)) {
+        assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
+      }
     }
     def onCanceled(): Unit = {
       assert(executeHolder.eventsManager.hasCanceled.contains(true))
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
index 33560cd53f6b..cb0bd8f771eb 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
@@ -91,6 +91,29 @@ class SparkConnectServiceE2ESuite extends 
SparkConnectServerTest {
     }
   }
 
+  test("Async cleanup callback gets called after the execution is closed") {
+    withClient(UUID.randomUUID().toString, defaultUserId) { client =>
+      val query1 = client.execute(buildPlan(BIG_ENOUGH_QUERY))
+      // Creating the iterator is lazy; call hasNext to trigger query1 to be 
sent.
+      query1.hasNext
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length 
== 1)
+      }
+      val executeHolder1 = 
SparkConnectService.executionManager.listExecuteHolders.head
+      // Release the session; its execution must then be closed and cleaned up asynchronously.
+      client.releaseSession()
+      // Check that the query gets cancelled and its execute holder removed
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length 
== 0)
+        // No execute holders should remain for the released session.
+      }
+      // Check that the async execution cleanup callback gets called
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(executeHolder1.completionCallbackCalled)
+      }
+    }
+  }
+
   private def testReleaseSessionTwoSessions(
       sessionIdA: String,
       userIdA: String,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to