This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 292a1131b542 [SPARK-44855][CONNECT] Small tweaks to attaching ExecuteGrpcResponseSender to ExecuteResponseObserver
292a1131b542 is described below

commit 292a1131b542ddc7b227a7e51e4f4233f3d2f9d8
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Wed Oct 11 15:01:20 2023 -0400

    [SPARK-44855][CONNECT] Small tweaks to attaching ExecuteGrpcResponseSender to ExecuteResponseObserver
    
    ### What changes were proposed in this pull request?
    
    Small improvements can be made to the way a new ExecuteGrpcResponseSender is attached to the observer:
    * Since we now have addGrpcResponseSender in ExecuteHolder, it should be ExecuteHolder's responsibility to interrupt the old sender and ensure that there is only one sender at a time, rather than ExecuteResponseObserver's responsibility.
    * executeObserver is used as a lock for synchronization. An explicit lock object could be better; see the sketch below.
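    
    To illustrate the second point, here is a minimal self-contained sketch of the explicit-lock pattern (ResponseBuffer and its methods are hypothetical simplifications, not the actual Spark Connect classes):
    
    ```scala
    // Hypothetical, simplified producer/consumer pair (not the actual Spark
    // Connect classes). A dedicated lock object scopes wait/notify to response
    // signalling, instead of synchronizing on the observer instance itself,
    // whose monitor any other code could also take.
    class ResponseBuffer[T] {
      private val responseLock = new Object()
      private var responses = Vector.empty[T]
    
      // Producer side: cache a response and wake any waiting consumer.
      def produce(r: T): Unit = responseLock.synchronized {
        responses :+= r
        responseLock.notifyAll()
      }
    
      // Consumer side: wait (bounded) for the response at `index` to exist.
      def consume(index: Int, timeoutMs: Long): Option[T] = responseLock.synchronized {
        if (index >= responses.size) responseLock.wait(timeoutMs)
        responses.lift(index)
      }
    }
    ```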
    
    Fix a small bug where ExecuteGrpcResponseSender would not be woken up by an interrupt if it was sleeping on the grpcCallObserverReadySignal. This would result in the sender potentially sleeping until the deadline (2 minutes) and only then being removed, which could delay timing the execution out by these 2 minutes. It should **not** cause any hang or wait on the client side, because if ExecuteGrpcResponseSender is interrupted, it means that the client has already come back with a ne [...]
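    
    A minimal sketch of the wake-up fix (Sender here is a hypothetical stand-in for ExecuteGrpcResponseSender, under the assumption that a sender can be parked on either of two monitors):
    
    ```scala
    // Hypothetical sketch of the fix (Sender stands in for ExecuteGrpcResponseSender).
    // A sender may be parked on either of two monitors: responseLock, waiting for a
    // new response, or grpcCallObserverReadySignal, waiting for the gRPC stream to
    // become writable. interrupt() must notify both; before the fix, a sender
    // sleeping on the ready signal was only woken once its deadline expired.
    class Sender(responseLock: Object, grpcCallObserverReadySignal: Object) {
      @volatile private var interrupted = false
    
      def interrupt(): Unit = {
        interrupted = true
        wakeUp()
      }
    
      private def wakeUp(): Unit = {
        // Neither lock is ever held for long, so taking both here cannot block.
        responseLock.synchronized { responseLock.notifyAll() }
        grpcCallObserverReadySignal.synchronized { grpcCallObserverReadySignal.notifyAll() }
      }
    }
    ```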
    
    ### Why are the changes needed?
    
    Minor cleanup of previous work.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests in ReattachableExecuteSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43181 from juliuszsompolski/SPARK-44855.
    
    Authored-by: Juliusz Sompolski <ju...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../execution/ExecuteGrpcResponseSender.scala      | 26 ++++++++-----
 .../execution/ExecuteResponseObserver.scala        | 44 ++++++++++------------
 .../spark/sql/connect/service/ExecuteHolder.scala  |  4 ++
 3 files changed, 40 insertions(+), 34 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
index 08496c36b28a..ba5ecc7a045a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
@@ -63,15 +63,15 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
   /**
    * Interrupt this sender and make it exit.
    */
-  def interrupt(): Unit = executionObserver.synchronized {
+  def interrupt(): Unit = {
     interrupted = true
-    executionObserver.notifyAll()
+    wakeUp()
   }
 
   // For testing
-  private[connect] def setDeadline(deadlineMs: Long) = executionObserver.synchronized {
+  private[connect] def setDeadline(deadlineMs: Long) = {
     deadlineTimeMillis = deadlineMs
-    executionObserver.notifyAll()
+    wakeUp()
   }
 
   def run(lastConsumedStreamIndex: Long): Unit = {
@@ -152,9 +152,6 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
         s"lastConsumedStreamIndex=$lastConsumedStreamIndex")
     val startTime = System.nanoTime()
 
-    // register to be notified about available responses.
-    executionObserver.attachConsumer(this)
-
     var nextIndex = lastConsumedStreamIndex + 1
     var finished = false
 
@@ -191,7 +188,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
        sentResponsesSize > maximumResponseSize || deadlineTimeMillis < System.currentTimeMillis()
 
       logTrace(s"Trying to get next response with index=$nextIndex.")
-      executionObserver.synchronized {
+      executionObserver.responseLock.synchronized {
         logTrace(s"Acquired executionObserver lock.")
         val sleepStart = System.nanoTime()
         var sleepEnd = 0L
@@ -208,7 +205,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
           if (response.isEmpty) {
             val timeout = Math.max(1, deadlineTimeMillis - System.currentTimeMillis())
             logTrace(s"Wait for response to become available with 
timeout=$timeout ms.")
-            executionObserver.wait(timeout)
+            executionObserver.responseLock.wait(timeout)
             logTrace(s"Reacquired executionObserver lock after waiting.")
             sleepEnd = System.nanoTime()
           }
@@ -339,4 +336,15 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
       }
     }
   }
+
+  private def wakeUp(): Unit = {
+    // Can be sleeping on either of these two locks, wake them up.
+    // (Neither of these locks is ever taken for extended period of time, so this won't block)
+    executionObserver.responseLock.synchronized {
+      executionObserver.responseLock.notifyAll()
+    }
+    grpcCallObserverReadySignal.synchronized {
+      grpcCallObserverReadySignal.notifyAll()
+    }
+  }
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
index 859ec7e6b198..e99e3a94f73a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
@@ -33,8 +33,7 @@ import org.apache.spark.sql.connect.service.ExecuteHolder
 /**
 * This StreamObserver is running on the execution thread. Execution pushes responses to it, it
 * caches them. ExecuteResponseGRPCSender is the consumer of the responses ExecuteResponseObserver
- * "produces". It waits on the monitor of ExecuteResponseObserver. New produced responses notify
- * the monitor.
+ * "produces". It waits on the responseLock. New produced responses notify the responseLock.
  * @see
  *   getResponse.
  *
@@ -85,10 +84,12 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
   private[connect] var highestConsumedIndex: Long = 0
 
   /**
-   * Consumer that waits for available responses. There can be only one at a time, @see
-   * attachConsumer.
+   * Lock used for synchronization between responseObserver and grpcResponseSenders. *
+   * grpcResponseSenders wait on it for a new response to be available. * grpcResponseSenders also
+   * notify it to wake up when interrupted * responseObserver notifies it when new responses are
+   * available.
    */
-  private var responseSender: Option[ExecuteGrpcResponseSender[T]] = None
+  private[connect] val responseLock = new Object()
 
   // Statistics about cached responses.
   private val cachedSizeUntilHighestConsumed = CachedSize()
@@ -106,7 +107,7 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
     0
   }
 
-  def onNext(r: T): Unit = synchronized {
+  def onNext(r: T): Unit = responseLock.synchronized {
     if (finalProducedIndex.nonEmpty) {
       throw new IllegalStateException("Stream onNext can't be called after 
stream completed")
     }
@@ -125,10 +126,10 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
     logDebug(
       s"Execution opId=${executeHolder.operationId} produced response " +
         s"responseId=${responseId} idx=$lastProducedIndex")
-    notifyAll()
+    responseLock.notifyAll()
   }
 
-  def onError(t: Throwable): Unit = synchronized {
+  def onError(t: Throwable): Unit = responseLock.synchronized {
     if (finalProducedIndex.nonEmpty) {
       throw new IllegalStateException("Stream onError can't be called after 
stream completed")
     }
@@ -137,10 +138,10 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
     logDebug(
       s"Execution opId=${executeHolder.operationId} produced error. " +
         s"Last stream index is $lastProducedIndex.")
-    notifyAll()
+    responseLock.notifyAll()
   }
 
-  def onCompleted(): Unit = synchronized {
+  def onCompleted(): Unit = responseLock.synchronized {
     if (finalProducedIndex.nonEmpty) {
       throw new IllegalStateException("Stream onCompleted can't be called 
after stream completed")
     }
@@ -148,14 +149,7 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
     logDebug(
       s"Execution opId=${executeHolder.operationId} completed stream. " +
         s"Last stream index is $lastProducedIndex.")
-    notifyAll()
-  }
-
-  /** Attach a new consumer (ExecuteResponseGRPCSender). */
-  def attachConsumer(newSender: ExecuteGrpcResponseSender[T]): Unit = synchronized {
-    // interrupt the current sender before attaching new one
-    responseSender.foreach(_.interrupt())
-    responseSender = Some(newSender)
+    responseLock.notifyAll()
   }
 
   /**
@@ -163,7 +157,7 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
    * this response observer assumes that the response is consumed, and the response and previous
    * response can be uncached, keeping retryBufferSize of responses for the case of retries.
    */
-  def consumeResponse(index: Long): Option[CachedStreamResponse[T]] = synchronized {
+  def consumeResponse(index: Long): Option[CachedStreamResponse[T]] = responseLock.synchronized {
    // we index stream responses from 1, getting a lower index would be invalid.
     assert(index >= 1)
     // it would be invalid if consumer would skip a response
@@ -198,17 +192,17 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
   }
 
   /** Get the stream error if there is one, otherwise None. */
-  def getError(): Option[Throwable] = synchronized {
+  def getError(): Option[Throwable] = responseLock.synchronized {
     error
   }
 
  /** If the stream is finished, the index of the last response, otherwise None. */
-  def getLastResponseIndex(): Option[Long] = synchronized {
+  def getLastResponseIndex(): Option[Long] = responseLock.synchronized {
     finalProducedIndex
   }
 
   /** Get the index in the stream for given response id. */
-  def getResponseIndexById(responseId: String): Long = synchronized {
+  def getResponseIndexById(responseId: String): Long = responseLock.synchronized {
     responseIdToIndex.getOrElse(
       responseId,
       throw new SparkSQLException(
@@ -217,7 +211,7 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
   }
 
   /** Remove cached responses up to and including response with given id. */
-  def removeResponsesUntilId(responseId: String): Unit = synchronized {
+  def removeResponsesUntilId(responseId: String): Unit = responseLock.synchronized {
     val index = getResponseIndexById(responseId)
     removeResponsesUntilIndex(index)
     logDebug(
@@ -229,7 +223,7 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
   }
 
   /** Remove all cached responses */
-  def removeAll(): Unit = synchronized {
+  def removeAll(): Unit = responseLock.synchronized {
     removeResponsesUntilIndex(lastProducedIndex)
     logInfo(
       s"Release all for opId=${executeHolder.operationId}. Execution stats: " +
@@ -242,7 +236,7 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
   }
 
   /** Returns if the stream is finished. */
-  def completed(): Boolean = synchronized {
+  def completed(): Boolean = responseLock.synchronized {
     finalProducedIndex.isDefined
   }
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index 0593edc2f6fd..eed8cc01f7c6 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -164,6 +164,10 @@ private[connect] class ExecuteHolder(
   private def addGrpcResponseSender(
      sender: ExecuteGrpcResponseSender[proto.ExecutePlanResponse]) = synchronized {
     if (closedTime.isEmpty) {
+      // Interrupt all other senders - there can be only one active sender.
+      // Interrupted senders will remove themselves with removeGrpcResponseSender when they exit.
+      grpcResponseSenders.foreach(_.interrupt())
+      // And add this one.
       grpcResponseSenders += sender
       lastAttachedRpcTime = None
     } else {

