This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 418bba5ad60 [SPARK-44709][CONNECT] Run ExecuteGrpcResponseSender in 
reattachable execute in new thread to fix flow control
418bba5ad60 is described below

commit 418bba5ad6053449a141f3c9c31ed3ad998995b8
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Tue Aug 8 18:32:25 2023 +0200

    [SPARK-44709][CONNECT] Run ExecuteGrpcResponseSender in reattachable 
execute in new thread to fix flow control
    
    ### What changes were proposed in this pull request?
    
    If executePlan / reattachExecute handling is done directly on the GRPC 
thread, flow control OnReady events are queued until after the handler 
returns, so the OnReadyHandler is never notified while the handler is still running.
    The correct way to use it is for the handler to delegate work to another 
thread and exit. See https://github.com/grpc/grpc-java/issues/7361
    
    Tidied up and added a lot of logging and statistics to 
ExecuteGrpcResponseSender and ExecuteResponseObserver to be able to observe 
this behaviour.
    
    Followup work in https://issues.apache.org/jira/browse/SPARK-44625 is 
needed for cleanup of abandoned executions that will also make sure that these 
threads are joined.
    
    ### Why are the changes needed?
    
    ExecuteGrpcResponseSender gets stuck waiting on grpcCallObserverReadySignal 
because events from OnReadyHandler do not arrive.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added extensive debugging to ExecuteGrpcResponseSender and 
ExecuteResponseObserver, then tested and observed the behaviour of all the 
threads.
    
    Closes #42355 from juliuszsompolski/spark-rpc-extra-thread.
    
    Authored-by: Juliusz Sompolski <ju...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../apache/spark/sql/connect/config/Connect.scala  |  13 +-
 .../connect/execution/CachedStreamResponse.scala   |   2 +
 .../execution/ExecuteGrpcResponseSender.scala      | 164 +++++++++++++++------
 .../execution/ExecuteResponseObserver.scala        | 116 ++++++++++++---
 .../connect/execution/ExecuteThreadRunner.scala    |   3 +-
 .../spark/sql/connect/service/ExecuteHolder.scala  |  21 ++-
 .../service/SparkConnectExecutePlanHandler.scala   |  20 +--
 .../SparkConnectReattachExecuteHandler.scala       |  22 +--
 8 files changed, 264 insertions(+), 97 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index e25cb5cbab2..0be53064cc0 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -74,6 +74,17 @@ object Connect {
       .intConf
       .createWithDefault(1024)
 
+  val CONNECT_EXECUTE_REATTACHABLE_ENABLED =
+    ConfigBuilder("spark.connect.execute.reattachable.enabled")
+      .internal()
+      .doc("Enables reattachable execution on the server. If disabled and a 
client requests it, " +
+        "non-reattachable execution will follow and should run until query 
completion. This will " +
+        "work, unless there is a GRPC stream error, in which case the client 
will discover that " +
+        "execution is not reattachable when trying to reattach fails.")
+      .version("3.5.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION =
     ConfigBuilder("spark.connect.execute.reattachable.senderMaxStreamDuration")
       .internal()
@@ -82,7 +93,7 @@ object Connect {
         "Set to 0 for unlimited.")
       .version("3.5.0")
       .timeConf(TimeUnit.MILLISECONDS)
-      .createWithDefaultString("5m")
+      .createWithDefaultString("2m")
 
   val CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE =
     ConfigBuilder("spark.connect.execute.reattachable.senderMaxStreamSize")
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala
index ec9fce785ba..a2bbe14f201 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala
@@ -22,6 +22,8 @@ import com.google.protobuf.MessageLite
 private[execution] case class CachedStreamResponse[T <: MessageLite](
     // the actual cached response
     response: T,
+    // the id of the response, a UUID.
+    responseId: String,
     // index of the response in the response stream.
     // responses produced in the stream are numbered consecutively starting 
from 1.
     streamIndex: Long) {
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
index 88124080cca..7b51a90ca37 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
@@ -17,24 +17,24 @@
 
 package org.apache.spark.sql.connect.execution
 
-import com.google.protobuf.MessageLite
+import com.google.protobuf.Message
 import io.grpc.stub.{ServerCallStreamObserver, StreamObserver}
 
 import org.apache.spark.{SparkEnv, SparkSQLException}
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connect.common.ProtoUtils
 import 
org.apache.spark.sql.connect.config.Connect.{CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION,
 CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE}
 import org.apache.spark.sql.connect.service.ExecuteHolder
 
 /**
- * ExecuteGrpcResponseSender sends responses to the GRPC stream. It runs on 
the RPC thread, and
- * gets notified by ExecuteResponseObserver about available responses. It 
notifies the
- * ExecuteResponseObserver back about cached responses that can be removed 
after being sent out.
+ * ExecuteGrpcResponseSender sends responses to the GRPC stream. It consumes 
responses from
+ * ExecuteResponseObserver and sends them out as responses to ExecutePlan or 
ReattachExecute.
  * @param executeHolder
  *   The execution this sender attaches to.
  * @param grpcObserver
  *   the GRPC request StreamObserver
  */
-private[connect] class ExecuteGrpcResponseSender[T <: MessageLite](
+private[connect] class ExecuteGrpcResponseSender[T <: Message](
     val executeHolder: ExecuteHolder,
     grpcObserver: StreamObserver[T])
     extends Logging {
@@ -47,15 +47,69 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
MessageLite](
   // Signal to wake up when grpcCallObserver.isReady()
   private val grpcCallObserverReadySignal = new Object
 
+  // Stats
+  private var consumeSleep = 0L
+  private var sendSleep = 0L
+
   /**
    * Detach this sender from executionObserver. Called only from 
executionObserver that this
-   * sender is attached to. executionObserver holds lock, and needs to notify 
after this call.
+   * sender is attached to. Lock on executionObserver is held, and notifyAll 
will wake up this
+   * sender if sleeping.
    */
-  def detach(): Unit = {
+  def detach(): Unit = executionObserver.synchronized {
     if (detached == true) {
       throw new IllegalStateException("ExecuteGrpcResponseSender already 
detached!")
     }
     detached = true
+    executionObserver.notifyAll()
+  }
+
+  def run(lastConsumedStreamIndex: Long): Unit = {
+    if (executeHolder.reattachable) {
+      // In reattachable execution, check if grpcObserver is ready for 
sending, by using
+      // setOnReadyHandler of the ServerCallStreamObserver. Otherwise, calling 
grpcObserver.onNext
+      // can queue the responses without sending them, and it is unknown how 
far behind it is, and
+      // hence how much the executionObserver needs to buffer.
+      //
+      // Because OnReady events get queued on the same GRPC inbound queue as 
the executePlan or
+      // reattachExecute RPC handler that this is executing in, OnReady events 
will not arrive and
+      // not trigger the OnReadyHandler unless this thread returns from 
executePlan/reattachExecute.
+      // Therefore, we launch another thread to operate on the grpcObserver 
and send the responses,
+      // while this thread will exit from the executePlan/reattachExecute 
call, allowing GRPC
+      // to send the OnReady events.
+      // See https://github.com/grpc/grpc-java/issues/7361
+
+      val t = new Thread(
+        s"SparkConnectGRPCSender_" +
+          
s"opId=${executeHolder.operationId}_startIndex=$lastConsumedStreamIndex") {
+        override def run(): Unit = {
+          execute(lastConsumedStreamIndex)
+        }
+      }
+      executeHolder.grpcSenderThreads += t
+
+      val grpcCallObserver = 
grpcObserver.asInstanceOf[ServerCallStreamObserver[T]]
+      grpcCallObserver.setOnReadyHandler(() => {
+        logTrace(s"Stream ready, notify grpcCallObserverReadySignal.")
+        grpcCallObserverReadySignal.synchronized {
+          grpcCallObserverReadySignal.notifyAll()
+        }
+      })
+
+      // Start the thread and exit
+      t.start()
+    } else {
+      // Non reattachable execute runs directly in the GRPC thread.
+      try {
+        execute(lastConsumedStreamIndex)
+      } finally {
+        if (!executeHolder.reattachable) {
+          // Non reattachable executions release here immediately.
+          // (Reattachable executions release with ReleaseExecute RPC.)
+          executeHolder.close()
+        }
+      }
+    }
   }
 
   /**
@@ -70,26 +124,16 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
MessageLite](
    *   the last index that was already consumed and sent. This sender will 
start from index after
    *   that. 0 means start from beginning (since first response has index 1)
    */
-  def run(lastConsumedStreamIndex: Long): Unit = {
-    logDebug(
-      s"GrpcResponseSender run for $executeHolder, " +
+  def execute(lastConsumedStreamIndex: Long): Unit = {
+    logInfo(
+      s"Starting for opId=${executeHolder.operationId}, " +
         s"reattachable=${executeHolder.reattachable}, " +
         s"lastConsumedStreamIndex=$lastConsumedStreamIndex")
+    val startTime = System.nanoTime()
 
     // register to be notified about available responses.
     executionObserver.attachConsumer(this)
 
-    // In reattachable execution, we check if grpcCallObserver is ready for 
sending.
-    // See sendResponse
-    if (executeHolder.reattachable) {
-      val grpcCallObserver = 
grpcObserver.asInstanceOf[ServerCallStreamObserver[T]]
-      grpcCallObserver.setOnReadyHandler(() => {
-        grpcCallObserverReadySignal.synchronized {
-          grpcCallObserverReadySignal.notifyAll()
-        }
-      })
-    }
-
     var nextIndex = lastConsumedStreamIndex + 1
     var finished = false
 
@@ -129,30 +173,38 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
MessageLite](
       // Get next available response.
       // Wait until either this sender got detached or next response is ready,
       // or the stream is complete and it had already sent all responses.
-      logDebug(s"Trying to get next response with index=$nextIndex.")
+      logTrace(s"Trying to get next response with index=$nextIndex.")
       executionObserver.synchronized {
-        logDebug(s"Acquired executionObserver lock.")
+        logTrace(s"Acquired executionObserver lock.")
+        val sleepStart = System.nanoTime()
+        var sleepEnd = 0L
         while (!detachedFromObserver &&
           !gotResponse &&
           !streamFinished &&
           !deadlineLimitReached) {
-          logDebug(s"Try to get response with index=$nextIndex from observer.")
+          logTrace(s"Try to get response with index=$nextIndex from observer.")
           response = executionObserver.consumeResponse(nextIndex)
-          logDebug(s"Response index=$nextIndex from observer: 
${response.isDefined}")
+          logTrace(s"Response index=$nextIndex from observer: 
${response.isDefined}")
           // If response is empty, release executionObserver lock and wait to 
get notified.
           // The state of detached, response and lastIndex are change under 
lock in
           // executionObserver, and will notify upon state change.
           if (response.isEmpty) {
             val timeout = Math.max(1, deadlineTimeMillis - 
System.currentTimeMillis())
-            logDebug(s"Wait for response to become available with 
timeout=$timeout ms.")
+            logTrace(s"Wait for response to become available with 
timeout=$timeout ms.")
             executionObserver.wait(timeout)
-            logDebug(s"Reacquired executionObserver lock after waiting.")
+            logTrace(s"Reacquired executionObserver lock after waiting.")
+            sleepEnd = System.nanoTime()
           }
         }
-        logDebug(
-          s"Exiting loop: detached=$detached, response=$response, " +
+        logTrace(
+          s"Exiting loop: detached=$detached, " +
+            s"response=${response.map(r => 
ProtoUtils.abbreviate(r.response))}, " +
             s"lastIndex=${executionObserver.getLastResponseIndex()}, " +
             s"deadline=${deadlineLimitReached}")
+        if (sleepEnd > 0) {
+          consumeSleep += sleepEnd - sleepStart
+          logTrace(s"Slept waiting for execution stream for ${sleepEnd - 
sleepStart}ns.")
+        }
       }
 
       // Process the outcome of the inner loop.
@@ -160,24 +212,30 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
MessageLite](
         // This sender got detached by the observer.
         // This only happens if this RPC is actually dead, and the client 
already came back with
         // a ReattachExecute RPC. Kill this RPC.
-        logDebug(s"Detached from observer at index ${nextIndex - 1}. Complete 
stream.")
+        logWarning(
+          s"Got detached from opId=${executeHolder.operationId} at index 
${nextIndex - 1}." +
+            s"totalTime=${System.nanoTime - startTime}ns " +
+            s"waitingForResults=${consumeSleep}ns 
waitingForSend=${sendSleep}ns")
         throw new SparkSQLException(errorClass = 
"INVALID_CURSOR.DISCONNECTED", Map.empty)
       } else if (gotResponse) {
         // There is a response available to be sent.
-        val sent = sendResponse(response.get.response, deadlineTimeMillis)
+        val sent = sendResponse(response.get, deadlineTimeMillis)
         if (sent) {
-          logDebug(s"Sent response index=$nextIndex.")
           sentResponsesSize += response.get.serializedByteSize
           nextIndex += 1
           assert(finished == false)
         } else {
-          // If it wasn't sent, time deadline must have been reached before 
stream became available.
+          // If it wasn't sent, time deadline must have been reached before 
stream became available,
+          // will exit in the next loop iteration.
           assert(deadlineLimitReached)
-          finished = true
         }
       } else if (streamFinished) {
         // Stream is finished and all responses have been sent
-        logDebug(s"Stream finished and sent all responses up to index 
${nextIndex - 1}.")
+        logInfo(
+          s"Stream finished for opId=${executeHolder.operationId}, " +
+            s"sent all responses up to last index ${nextIndex - 1}. " +
+            s"totalTime=${System.nanoTime - startTime}ns " +
+            s"waitingForResults=${consumeSleep}ns 
waitingForSend=${sendSleep}ns")
         executionObserver.getError() match {
           case Some(t) => grpcObserver.onError(t)
           case None => grpcObserver.onCompleted()
@@ -186,7 +244,11 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
MessageLite](
       } else if (deadlineLimitReached) {
         // The stream is not complete, but should be finished now.
         // The client needs to reattach with ReattachExecute.
-        logDebug(s"Deadline reached, finishing stream after index ${nextIndex 
- 1}.")
+        logInfo(
+          s"Deadline reached, shutting down stream for 
opId=${executeHolder.operationId} " +
+            s"after index ${nextIndex - 1}. " +
+            s"totalTime=${System.nanoTime - startTime}ns " +
+            s"waitingForResults=${consumeSleep}ns 
waitingForSend=${sendSleep}ns")
         grpcObserver.onCompleted()
         finished = true
       }
@@ -205,10 +267,15 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
MessageLite](
    * @return
    *   true if the response was sent, false otherwise (meaning deadline passed)
    */
-  private def sendResponse(response: T, deadlineTimeMillis: Long): Boolean = {
+  private def sendResponse(
+      response: CachedStreamResponse[T],
+      deadlineTimeMillis: Long): Boolean = {
     if (!executeHolder.reattachable) {
       // no flow control in non-reattachable execute
-      grpcObserver.onNext(response)
+      logDebug(
+        s"SEND opId=${executeHolder.operationId} 
responseId=${response.responseId} " +
+          s"idx=${response.streamIndex} (no flow control)")
+      grpcObserver.onNext(response.response)
       true
     } else {
       // In reattachable execution, we control the flow, and only pass the 
response to the
@@ -225,19 +292,28 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
MessageLite](
       val grpcCallObserver = 
grpcObserver.asInstanceOf[ServerCallStreamObserver[T]]
 
       grpcCallObserverReadySignal.synchronized {
-        logDebug(s"Acquired grpcCallObserverReadySignal lock.")
+        logTrace(s"Acquired grpcCallObserverReadySignal lock.")
+        val sleepStart = System.nanoTime()
+        var sleepEnd = 0L
         while (!grpcCallObserver.isReady() && deadlineTimeMillis >= 
System.currentTimeMillis()) {
           val timeout = Math.max(1, deadlineTimeMillis - 
System.currentTimeMillis())
-          logDebug(s"Wait for grpcCallObserver to become ready with 
timeout=$timeout ms.")
+          var sleepStart = System.nanoTime()
+          logTrace(s"Wait for grpcCallObserver to become ready with 
timeout=$timeout ms.")
           grpcCallObserverReadySignal.wait(timeout)
-          logDebug(s"Reacquired grpcCallObserverReadySignal lock after 
waiting.")
+          logTrace(s"Reacquired grpcCallObserverReadySignal lock after 
waiting.")
+          sleepEnd = System.nanoTime()
         }
         if (grpcCallObserver.isReady()) {
-          logDebug(s"grpcCallObserver is ready, sending response.")
-          grpcCallObserver.onNext(response)
+          val sleepTime = if (sleepEnd > 0L) sleepEnd - sleepStart else 0L
+          logDebug(
+            s"SEND opId=${executeHolder.operationId} 
responseId=${response.responseId} " +
+              s"idx=${response.streamIndex}" +
+              s"(waiting ${sleepTime}ns for GRPC stream to be ready)")
+          sendSleep += sleepTime
+          grpcCallObserver.onNext(response.response)
           true
         } else {
-          logDebug(s"grpcCallObserver is not ready, exiting.")
+          logTrace(s"grpcCallObserver is not ready, exiting.")
           false
         }
       }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
index 8af0f72b8da..0573f7b3dae 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
@@ -21,13 +21,13 @@ import java.util.UUID
 
 import scala.collection.mutable
 
-import com.google.protobuf.MessageLite
+import com.google.protobuf.Message
 import io.grpc.stub.StreamObserver
 
 import org.apache.spark.{SparkEnv, SparkSQLException}
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
-import 
org.apache.spark.sql.connect.config.Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE
+import 
org.apache.spark.sql.connect.config.Connect.CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE
 import org.apache.spark.sql.connect.service.ExecuteHolder
 
 /**
@@ -47,7 +47,7 @@ import org.apache.spark.sql.connect.service.ExecuteHolder
  * @see
  *   attachConsumer
  */
-private[connect] class ExecuteResponseObserver[T <: MessageLite](val 
executeHolder: ExecuteHolder)
+private[connect] class ExecuteResponseObserver[T <: Message](val 
executeHolder: ExecuteHolder)
     extends StreamObserver[T]
     with Logging {
 
@@ -85,12 +85,18 @@ private[connect] class ExecuteResponseObserver[T <: 
MessageLite](val executeHold
    */
   private var responseSender: Option[ExecuteGrpcResponseSender[T]] = None
 
+  // Statistics about cached responses.
+  private var cachedSizeUntilHighestConsumed = CachedSize()
+  private var cachedSizeUntilLastProduced = CachedSize()
+  private var autoRemovedSize = CachedSize()
+  private var totalSize = CachedSize()
+
   /**
    * Total size of response to be held buffered after giving out with 
getResponse. 0 for none, any
    * value greater than 0 will buffer the response from getResponse.
    */
   private val retryBufferSize = if (executeHolder.reattachable) {
-    
SparkEnv.get.conf.get(CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE).toLong
+    
SparkEnv.get.conf.get(CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE).toLong
   } else {
     0
   }
@@ -101,11 +107,19 @@ private[connect] class ExecuteResponseObserver[T <: 
MessageLite](val executeHold
     }
     lastProducedIndex += 1
     val processedResponse = setCommonResponseFields(r)
-    responses +=
-      ((lastProducedIndex, CachedStreamResponse[T](processedResponse, 
lastProducedIndex)))
-    responseIndexToId += ((lastProducedIndex, 
getResponseId(processedResponse)))
-    responseIdToIndex += ((getResponseId(processedResponse), 
lastProducedIndex))
-    logDebug(s"Saved response with index=$lastProducedIndex")
+    val responseId = getResponseId(processedResponse)
+    val response = CachedStreamResponse[T](processedResponse, responseId, 
lastProducedIndex)
+
+    responses += ((lastProducedIndex, response))
+    responseIndexToId += ((lastProducedIndex, responseId))
+    responseIdToIndex += ((responseId, lastProducedIndex))
+
+    cachedSizeUntilLastProduced.add(response)
+    totalSize.add(response)
+
+    logDebug(
+      s"Execution opId=${executeHolder.operationId} produced response " +
+        s"responseId=${responseId} idx=$lastProducedIndex")
     notifyAll()
   }
 
@@ -115,7 +129,9 @@ private[connect] class ExecuteResponseObserver[T <: 
MessageLite](val executeHold
     }
     error = Some(t)
     finalProducedIndex = Some(lastProducedIndex) // no responses to be send 
after error.
-    logDebug(s"Error. Last stream index is $lastProducedIndex.")
+    logDebug(
+      s"Execution opId=${executeHolder.operationId} produced error. " +
+        s"Last stream index is $lastProducedIndex.")
     notifyAll()
   }
 
@@ -124,18 +140,17 @@ private[connect] class ExecuteResponseObserver[T <: 
MessageLite](val executeHold
       throw new IllegalStateException("Stream onCompleted can't be called 
after stream completed")
     }
     finalProducedIndex = Some(lastProducedIndex)
-    logDebug(s"Completed. Last stream index is $lastProducedIndex.")
+    logDebug(
+      s"Execution opId=${executeHolder.operationId} completed stream. " +
+        s"Last stream index is $lastProducedIndex.")
     notifyAll()
   }
 
   /** Attach a new consumer (ExecuteResponseGRPCSender). */
   def attachConsumer(newSender: ExecuteGrpcResponseSender[T]): Unit = 
synchronized {
     // detach the current sender before attaching new one
-    // this.synchronized() needs to be held while detaching a sender, and the 
detached sender
-    // needs to be notified with notifyAll() afterwards.
     responseSender.foreach(_.detach())
     responseSender = Some(newSender)
-    notifyAll() // consumer
   }
 
   /**
@@ -150,9 +165,18 @@ private[connect] class ExecuteResponseObserver[T <: 
MessageLite](val executeHold
     assert(index <= highestConsumedIndex + 1)
     val ret = responses.get(index)
     if (ret.isDefined) {
-      if (index > highestConsumedIndex) highestConsumedIndex = index
+      if (index > highestConsumedIndex) {
+        highestConsumedIndex = index
+        cachedSizeUntilHighestConsumed.add(ret.get)
+      }
       // When the response is consumed, figure what previous responses can be 
uncached.
-      removeCachedResponses(index)
+      // (We keep at least one response before the one we send to consumer now)
+      removeCachedResponses(index - 1)
+      logDebug(
+        s"CONSUME opId=${executeHolder.operationId} 
responseId=${ret.get.responseId} " +
+          s"idx=$index. size=${ret.get.serializedByteSize} " +
+          s"cachedUntilConsumed=$cachedSizeUntilHighestConsumed " +
+          s"cachedUntilProduced=$cachedSizeUntilLastProduced")
     } else if (index <= highestConsumedIndex) {
       // If index is <= highestConsumedIndex and not available, it was already 
removed from cache.
       // This may happen if ReattachExecute is too late and the cached 
response was evicted.
@@ -191,6 +215,25 @@ private[connect] class ExecuteResponseObserver[T <: 
MessageLite](val executeHold
   def removeResponsesUntilId(responseId: String): Unit = synchronized {
     val index = getResponseIndexById(responseId)
     removeResponsesUntilIndex(index)
+    logDebug(
+      s"RELEASE opId=${executeHolder.operationId} until " +
+        s"responseId=$responseId " +
+        s"idx=$index. " +
+        s"cachedUntilConsumed=$cachedSizeUntilHighestConsumed " +
+        s"cachedUntilProduced=$cachedSizeUntilLastProduced")
+  }
+
+  /** Remove all cached responses */
+  def removeAll(): Unit = synchronized {
+    removeResponsesUntilIndex(lastProducedIndex)
+    logInfo(
+      s"Release all for opId=${executeHolder.operationId}. Execution stats: " +
+        s"total=${totalSize} " +
+        s"autoRemoved=${autoRemovedSize} " +
+        s"cachedUntilConsumed=$cachedSizeUntilHighestConsumed " +
+        s"cachedUntilProduced=$cachedSizeUntilLastProduced " +
+        s"maxCachedUntilConsumed=${cachedSizeUntilHighestConsumed.max} " +
+        s"maxCachedUntilProduced=${cachedSizeUntilLastProduced.max}")
   }
 
   /** Returns if the stream is finished. */
@@ -218,16 +261,31 @@ private[connect] class ExecuteResponseObserver[T <: 
MessageLite](val executeHold
       totalResponsesSize += responses.get(i).get.serializedByteSize
       i -= 1
     }
-    removeResponsesUntilIndex(i)
+    if (responses.get(i).isDefined) {
+      logDebug(
+        s"AUTORELEASE opId=${executeHolder.operationId} until idx=$i. " +
+          s"cachedUntilConsumed=$cachedSizeUntilHighestConsumed " +
+          s"cachedUntilProduced=$cachedSizeUntilLastProduced")
+      removeResponsesUntilIndex(i, true)
+    } else {
+      logDebug(
+        s"NO AUTORELEASE opId=${executeHolder.operationId}. " +
+          s"cachedUntilConsumed=$cachedSizeUntilHighestConsumed " +
+          s"cachedUntilProduced=$cachedSizeUntilLastProduced")
+    }
   }
 
   /**
    * Remove cached responses until given index. Iterating backwards, once an 
index is encountered
    * that has been removed, all earlier indexes would also be removed.
    */
-  private def removeResponsesUntilIndex(index: Long) = {
+  private def removeResponsesUntilIndex(index: Long, autoRemoved: Boolean = 
false) = {
     var i = index
     while (i >= 1 && responses.get(i).isDefined) {
+      val r = responses.get(i).get
+      cachedSizeUntilHighestConsumed.remove(r)
+      cachedSizeUntilLastProduced.remove(r)
+      if (autoRemoved) autoRemovedSize.add(r)
       responses.remove(i)
       i -= 1
     }
@@ -258,4 +316,26 @@ private[connect] class ExecuteResponseObserver[T <: 
MessageLite](val executeHold
         executePlanResponse.getResponseId
     }
   }
+
+  /**
+   * Helper for counting statistics about cached responses.
+   */
+  private case class CachedSize(var bytes: Long = 0L, var num: Long = 0L) {
+    var maxBytes: Long = 0L
+    var maxNum: Long = 0L
+
+    def add(t: CachedStreamResponse[T]): Unit = {
+      bytes += t.serializedByteSize
+      if (bytes > maxBytes) maxBytes = bytes
+      num += 1
+      if (num > maxNum) maxNum = num
+    }
+
+    def remove(t: CachedStreamResponse[T]): Unit = {
+      bytes -= t.serializedByteSize
+      num -= 1
+    }
+
+    def max: CachedSize = CachedSize(maxBytes, maxNum)
+  }
 }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 930ccae5d4c..62083d4892f 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -222,7 +222,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
       .build()
   }
 
-  private class ExecutionThread extends Thread {
+  private class ExecutionThread
+      extends 
Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") {
     override def run(): Unit = {
       execute()
     }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index 4eb90f9f163..105af0dc0ba 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql.connect.service
 import java.util.UUID
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
-import org.apache.spark.SparkSQLException
+import org.apache.spark.{SparkEnv, SparkSQLException}
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.common.ProtoUtils
+import 
org.apache.spark.sql.connect.config.Connect.CONNECT_EXECUTE_REATTACHABLE_ENABLED
 import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender, 
ExecuteResponseObserver, ExecuteThreadRunner}
 import org.apache.spark.util.SystemClock
 
@@ -36,6 +38,8 @@ private[connect] class ExecuteHolder(
     val sessionHolder: SessionHolder)
     extends Logging {
 
+  val session = sessionHolder.session
+
   val operationId = if (request.hasOperationId) {
     try {
       UUID.fromString(request.getOperationId).toString
@@ -73,8 +77,11 @@ private[connect] class ExecuteHolder(
    * If execution is reattachable, it's life cycle is not limited to a single 
ExecutePlanRequest,
    * but can be reattached with ReattachExecute, and released with 
ReleaseExecute
    */
-  val reattachable: Boolean = request.getRequestOptionsList.asScala.exists { 
option =>
-    option.hasReattachOptions && option.getReattachOptions.getReattachable == 
true
+  val reattachable: Boolean = {
+    SparkEnv.get.conf.get(CONNECT_EXECUTE_REATTACHABLE_ENABLED) &&
+    request.getRequestOptionsList.asScala.exists { option =>
+      option.hasReattachOptions && option.getReattachOptions.getReattachable 
== true
+    }
   }
 
   /**
@@ -83,7 +90,12 @@ private[connect] class ExecuteHolder(
    */
   var attached: Boolean = true
 
-  val session = sessionHolder.session
+  /**
+   * Threads that execute the ExecuteGrpcResponseSender and send the GRPC 
responses.
+   *
+   * TODO(SPARK-44625): Joining and cleaning up these threads during cleanup.
+   */
+  val grpcSenderThreads: mutable.ArrayBuffer[Thread] = new 
mutable.ArrayBuffer[Thread]()
 
   val responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse] =
     new ExecuteResponseObserver[proto.ExecutePlanResponse](this)
@@ -162,6 +174,7 @@ private[connect] class ExecuteHolder(
   def close(): Unit = {
     runner.interrupt()
     runner.join()
+    responseObserver.removeAll()
     eventsManager.postClosed()
     sessionHolder.removeExecuteHolder(operationId)
   }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
index 0226b4e5ed3..9daf1e17b5e 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
@@ -31,20 +31,10 @@ class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.Exec
       .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
     val executeHolder = sessionHolder.createExecuteHolder(v)
 
-    try {
-      executeHolder.eventsManager.postStarted()
-      executeHolder.start()
-      val responseSender =
-        new 
ExecuteGrpcResponseSender[proto.ExecutePlanResponse](executeHolder, 
responseObserver)
-      executeHolder.attachAndRunGrpcResponseSender(responseSender)
-    } finally {
-      if (!executeHolder.reattachable) {
-        // Non reattachable executions release here immediately.
-        executeHolder.close()
-      } else {
-        // Reattachable executions close release with ReleaseExecute RPC.
-        // TODO We mark in the ExecuteHolder that RPC detached.
-      }
-    }
+    executeHolder.eventsManager.postStarted()
+    executeHolder.start()
+    val responseSender =
+      new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](executeHolder, 
responseObserver)
+    executeHolder.attachAndRunGrpcResponseSender(responseSender)
   }
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala
index 362846a87b5..b70c82ab137 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala
@@ -44,20 +44,14 @@ class SparkConnectReattachExecuteHandler(
         messageParameters = Map.empty)
     }
 
-    try {
-      val responseSender =
-        new 
ExecuteGrpcResponseSender[proto.ExecutePlanResponse](executeHolder, 
responseObserver)
-      if (v.hasLastResponseId) {
-        // start from response after lastResponseId
-        executeHolder.attachAndRunGrpcResponseSender(responseSender, 
v.getLastResponseId)
-      } else {
-        // start from the start of the stream.
-        executeHolder.attachAndRunGrpcResponseSender(responseSender)
-      }
-    } finally {
-      // Reattachable executions do not free the execution here, but client 
needs to call
-      // ReleaseExecute RPC.
-      // TODO We mark in the ExecuteHolder that RPC detached.
+    val responseSender =
+      new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](executeHolder, 
responseObserver)
+    if (v.hasLastResponseId) {
+      // start from response after lastResponseId
+      executeHolder.attachAndRunGrpcResponseSender(responseSender, 
v.getLastResponseId)
+    } else {
+      // start from the start of the stream.
+      executeHolder.attachAndRunGrpcResponseSender(responseSender)
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to