This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new cdd077a39fd9 [SPARK-47819][CONNECT][3.5] Use asynchronous callback for execution cleanup cdd077a39fd9 is described below commit cdd077a39fd99ba7c2fba4e89f6ef9668cf3cbce Author: Xi Lyu <xi....@databricks.com> AuthorDate: Wed Apr 24 09:08:59 2024 -0400 [SPARK-47819][CONNECT][3.5] Use asynchronous callback for execution cleanup ([Original PR](https://github.com/apache/spark/pull/46027)) ### What changes were proposed in this pull request? Expired sessions are regularly checked and cleaned up by a maintenance thread. However, currently, this process is synchronous. Therefore, in rare cases, interrupting the execution thread of a query in a session can take hours, causing the entire maintenance process to stall, resulting in a large amount of memory not being cleared. We address this by introducing asynchronous callbacks for execution cleanup, avoiding synchronous joins of execution threads, and preventing the maintenance thread from stalling in the above scenarios. To be more specific, instead of calling `runner.join()` in `ExecuteHolder.close()`, we set a post-cleanup function as the callback through `runner.processOnCompletion`, which will be called asynchronously once the execution runner is completed or interrupted. In this way, the maintenan [...] ### Why are the changes needed? In the rare cases mentioned above, performance can be severely affected. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests and a new test `Async cleanup callback gets called after the execution is closed` in `ReattachableExecuteSuite.scala`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46064 from xi-db/SPARK-47819-async-cleanup-3.5. 
Authored-by: Xi Lyu <xi....@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../connect/execution/ExecuteThreadRunner.scala | 31 +++++++++++++++++----- .../spark/sql/connect/service/ExecuteHolder.scala | 23 ++++++++-------- .../execution/ReattachableExecuteSuite.scala | 22 +++++++++++++++ .../connect/planner/SparkConnectServiceSuite.scala | 7 ++++- 4 files changed, 64 insertions(+), 19 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 62083d4892f7..d503dde3d18c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connect.execution +import scala.concurrent.{ExecutionContext, Promise} +import scala.util.Try import scala.util.control.NonFatal import com.google.protobuf.Message @@ -29,7 +31,7 @@ import org.apache.spark.sql.connect.common.ProtoUtils import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag} import org.apache.spark.sql.connect.utils.ErrorUtils -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * This class launches the actual execution in an execution thread. The execution pushes the @@ -37,10 +39,12 @@ import org.apache.spark.util.Utils */ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging { + private val promise: Promise[Unit] = Promise[Unit]() + // The newly created thread will inherit all InheritableThreadLocals used by Spark, // e.g. SparkContext.localProperties. 
If considering implementing a thread-pool, // forwarding of thread locals needs to be taken into account. - private var executionThread: Thread = new ExecutionThread() + private val executionThread: Thread = new ExecutionThread(promise) private var interrupted: Boolean = false @@ -53,9 +57,11 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends executionThread.start() } - /** Joins the background execution thread after it is finished. */ - def join(): Unit = { - executionThread.join() + /** + * Register a callback that gets executed after completion/interruption of the execution + */ + private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit = { + promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext) } /** @@ -222,10 +228,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends .build() } - private class ExecutionThread + private class ExecutionThread(onCompletionPromise: Promise[Unit]) extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") { override def run(): Unit = { - execute() + try { + execute() + onCompletionPromise.success(()) + } catch { + case NonFatal(e) => + onCompletionPromise.failure(e) + } } } } + +private[connect] object ExecuteThreadRunner { + private implicit val namedExecutionContext: ExecutionContext = ExecutionContext + .fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback")) +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index 974c13b08e31..5cf63c2195ab 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -114,6 +114,9 @@ private[connect] class 
ExecuteHolder( : mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]] = new mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]() + /** For testing. Whether the async completion callback is called. */ + @volatile private[connect] var completionCallbackCalled: Boolean = false + /** * Start the execution. The execution is started in a background thread in ExecuteThreadRunner. * Responses are produced and cached in ExecuteResponseObserver. A GRPC thread consumes the @@ -125,13 +128,6 @@ private[connect] class ExecuteHolder( runner.start() } - /** - * Wait for the execution thread to finish and join it. - */ - def join(): Unit = { - runner.join() - } - /** * Attach an ExecuteGrpcResponseSender that will consume responses from the query and send them * out on the Grpc response stream. The sender will start from the start of the response stream. @@ -234,8 +230,15 @@ private[connect] class ExecuteHolder( if (closedTime.isEmpty) { // interrupt execution, if still running. runner.interrupt() - // wait for execution to finish, to make sure no more results get pushed to responseObserver - runner.join() + // Do not wait for the execution to finish, clean up resources immediately. + runner.processOnCompletion { _ => + completionCallbackCalled = true + // The execution may not immediately get interrupted, clean up any remaining resources when + // it does. 
+ responseObserver.removeAll() + // post closed to UI + eventsManager.postClosed() + } // interrupt any attached grpcResponseSenders grpcResponseSenders.foreach(_.interrupt()) // if there were still any grpcResponseSenders, register detach time @@ -245,8 +248,6 @@ private[connect] class ExecuteHolder( } // remove all cached responses from observer responseObserver.removeAll() - // post closed to UI - eventsManager.postClosed() closedTime = Some(System.currentTimeMillis()) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala index 0e29a07b719a..06cd1a5666b6 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala @@ -355,4 +355,26 @@ class ReattachableExecuteSuite extends SparkConnectServerTest { assertEventuallyNoActiveExecutions() } } + + test("Async cleanup callback gets called after the execution is closed") { + withClient { client => + val query1 = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + // just creating the iterator is lazy, trigger query1 to be sent. 
+ query1.hasNext + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(SparkConnectService.executionManager.listExecuteHolders.length == 1) + } + val executeHolder1 = SparkConnectService.executionManager.listExecuteHolders.head + // Close execution + SparkConnectService.executionManager.removeExecuteHolder(executeHolder1.key) + // Check that queries get cancelled + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(SparkConnectService.executionManager.listExecuteHolders.length == 0) + } + // Check the async execute cleanup get called + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(executeHolder1.completionCallbackCalled) + } + } + } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 90c9d13def61..06508bfc6a7c 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -31,6 +31,8 @@ import org.apache.arrow.vector.{BigIntVector, Float8Vector} import org.apache.arrow.vector.ipc.ArrowStreamReader import org.mockito.Mockito.when import org.scalatest.Tag +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar.convertIntToGrainOfTime import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkContext, SparkEnv} @@ -879,8 +881,11 @@ class SparkConnectServiceSuite assert(executeHolder.eventsManager.hasError.isDefined) } def onCompleted(producedRowCount: Option[Long] = None): Unit = { - assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount) + // The eventsManager is closed asynchronously + Eventually.eventually(timeout(1.seconds)) { + 
assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) + } } def onCanceled(): Unit = { assert(executeHolder.eventsManager.hasCanceled.contains(true)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org