This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new af8c0b999be [SPARK-44872][CONNECT] Server testing infra and 
ReattachableExecuteSuite
af8c0b999be is described below

commit af8c0b999be746b661efe2439ac015a0c7d12c00
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Tue Sep 12 16:48:26 2023 +0200

    [SPARK-44872][CONNECT] Server testing infra and ReattachableExecuteSuite
    
    ### What changes were proposed in this pull request?
    
    Add `SparkConnectServerTest` with infra to test a real server with a real 
client in the same process, but communicating over RPC.
    
    Add `ReattachableExecuteSuite` with some tests for reattachable execute.
    
    Two bugs were found by the tests:
    * Fix bug in `SparkConnectExecutionManager.createExecuteHolder` when 
attempting to resubmit an operation that was deemed abandoned. This bug is 
benign in reattachable execute, because reattachable execute would first send a 
ReattachExecute, which would be handled correctly in 
SparkConnectReattachExecuteHandler. For non-reattachable execute (disabled or 
old client), this is also a very unlikely scenario, because the retrying 
mechanism should be able to resubmit before the query is decl [...]
    * In `ExecuteGrpcResponseSender` there was an assertion that assumed that 
if `sendResponse` did not send, it was because the deadline was reached. But it 
can also be because of an interrupt. This would have resulted in an interrupt 
returning an assertion error instead of CURSOR_DISCONNECTED in testing. Since 
assertions are not enabled outside of testing, this was not a problem in 
production.
    
    ### Why are the changes needed?
    
    Testing of reattachable execute.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Tests added.
    
    Closes #42560 from juliuszsompolski/sc-reattachable-tests.
    
    Authored-by: Juliusz Sompolski <ju...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 4b96add471d292ed5c63ccc625489ff78cfb9b25)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../sql/connect/client/CloseableIterator.scala     |  22 +-
 .../client/CustomSparkConnectBlockingStub.scala    |   2 +-
 .../ExecutePlanResponseReattachableIterator.scala  |  18 +-
 .../connect/client/GrpcExceptionConverter.scala    |   5 +-
 .../sql/connect/client/GrpcRetryHandler.scala      |   4 +-
 .../execution/ExecuteGrpcResponseSender.scala      |  17 +-
 .../execution/ExecuteResponseObserver.scala        |   8 +-
 .../spark/sql/connect/service/ExecuteHolder.scala  |  10 +
 .../service/SparkConnectExecutionManager.scala     |  40 ++-
 .../spark/sql/connect/SparkConnectServerTest.scala | 261 +++++++++++++++
 .../execution/ReattachableExecuteSuite.scala       | 352 +++++++++++++++++++++
 .../scala/org/apache/spark/SparkFunSuite.scala     |  24 ++
 12 files changed, 735 insertions(+), 28 deletions(-)

diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
index 891e50ed6e7..d3fc9963edc 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
@@ -27,6 +27,20 @@ private[sql] trait CloseableIterator[E] extends Iterator[E] 
with AutoCloseable {
   }
 }
 
+private[sql] abstract class WrappedCloseableIterator[E] extends 
CloseableIterator[E] {
+  // Default delegate for next()/hasNext(); close() closes it if it is a CloseableIterator.
+  def innerIterator: Iterator[E]
+
+  override def next(): E = innerIterator.next()
+
+  override def hasNext(): Boolean = innerIterator.hasNext
+
+  override def close(): Unit = innerIterator match {
+    case it: CloseableIterator[E] => it.close()
+    case _ => // not closeable: nothing to release
+  }
+}
+
 private[sql] object CloseableIterator {
 
   /**
@@ -35,12 +49,8 @@ private[sql] object CloseableIterator {
   def apply[T](iterator: Iterator[T]): CloseableIterator[T] = iterator match {
     case closeable: CloseableIterator[T] => closeable
     case _ =>
-      new CloseableIterator[T] {
-        override def next(): T = iterator.next()
-
-        override def hasNext(): Boolean = iterator.hasNext
-
-        override def close() = { /* empty */ }
+      new WrappedCloseableIterator[T] {
+        override def innerIterator = iterator
       }
   }
 }
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index 73ff01e223f..80edcfa8be1 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -22,7 +22,7 @@ import io.grpc.ManagedChannel
 
 import org.apache.spark.connect.proto._
 
-private[client] class CustomSparkConnectBlockingStub(
+private[connect] class CustomSparkConnectBlockingStub(
     channel: ManagedChannel,
     retryPolicy: GrpcRetryHandler.RetryPolicy) {
 
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
index 9bf7de33da8..57a629264be 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.client
 
 import java.util.UUID
 
+import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
 import io.grpc.{ManagedChannel, StatusRuntimeException}
@@ -50,7 +51,7 @@ class ExecutePlanResponseReattachableIterator(
     request: proto.ExecutePlanRequest,
     channel: ManagedChannel,
     retryPolicy: GrpcRetryHandler.RetryPolicy)
-    extends CloseableIterator[proto.ExecutePlanResponse]
+    extends WrappedCloseableIterator[proto.ExecutePlanResponse]
     with Logging {
 
   val operationId = if (request.hasOperationId) {
@@ -86,14 +87,25 @@ class ExecutePlanResponseReattachableIterator(
   // True after ResultComplete message was seen in the stream.
   // Server will always send this message at the end of the stream, if the 
underlying iterator
   // finishes without producing one, another iterator needs to be reattached.
-  private var resultComplete: Boolean = false
+  // Visible for testing.
+  private[connect] var resultComplete: Boolean = false
 
   // Initial iterator comes from ExecutePlan request.
   // Note: This is not retried, because no error would ever be thrown here, 
and GRPC will only
   // throw error on first iter.hasNext() or iter.next()
-  private var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] =
+  // Visible for testing.
+  private[connect] var iter: 
Option[java.util.Iterator[proto.ExecutePlanResponse]] =
     Some(rawBlockingStub.executePlan(initialRequest))
 
+  override def innerIterator: Iterator[proto.ExecutePlanResponse] = iter match 
{
+    case Some(it) => it.asScala // Scala view of the current response iterator
+    case None =>
+      // The iterator is only unset for short moments while retry exception is 
thrown.
+      // It should only happen in the middle of internal processing. Since 
this iterator is not
+      // thread safe, no-one should be accessing it at this moment.
+      throw new IllegalStateException("innerIterator unset")
+  }
+
   override def next(): proto.ExecutePlanResponse = synchronized {
     // hasNext will trigger reattach in case the stream completed without 
resultComplete
     if (!hasNext()) {
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index c430485bd41..fe9f6dc2b4a 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -43,7 +43,10 @@ private[client] object GrpcExceptionConverter extends 
JsonUtils {
   }
 
   def convertIterator[T](iter: CloseableIterator[T]): CloseableIterator[T] = {
-    new CloseableIterator[T] {
+    new WrappedCloseableIterator[T] {
+
+      override def innerIterator: Iterator[T] = iter
+
       override def hasNext: Boolean = {
         convert {
           iter.hasNext
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
index 8791530607c..3c0b750fd46 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
@@ -48,11 +48,13 @@ private[sql] class GrpcRetryHandler(
    *   The type of the response.
    */
   class RetryIterator[T, U](request: T, call: T => CloseableIterator[U])
-      extends CloseableIterator[U] {
+      extends WrappedCloseableIterator[U] {
 
     private var opened = false // we only retry if it fails on first call when 
using the iterator
     private var iter = call(request)
 
+    override def innerIterator: Iterator[U] = iter
+
     private def retryIter[V](f: Iterator[U] => V) = {
       if (!opened) {
         opened = true
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
index 6b8fcde1156..c3c33a85d65 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
@@ -47,6 +47,9 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
Message](
 
   private var interrupted = false
 
+  // Time at which this sender should finish if the response stream is not 
finished by then.
+  private var deadlineTimeMillis = Long.MaxValue
+
   // Signal to wake up when grpcCallObserver.isReady()
   private val grpcCallObserverReadySignal = new Object
 
@@ -65,6 +68,12 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
Message](
     executionObserver.notifyAll()
   }
 
+  // For testing: overrides deadlineTimeMillis and wakes up threads waiting on executionObserver.
+  private[connect] def setDeadline(deadlineMs: Long) = 
executionObserver.synchronized {
+    deadlineTimeMillis = deadlineMs
+    executionObserver.notifyAll()
+  }
+
   def run(lastConsumedStreamIndex: Long): Unit = {
     if (executeHolder.reattachable) {
       // In reattachable execution we use setOnReadyHandler and 
grpcCallObserver.isReady to control
@@ -150,7 +159,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
Message](
     var finished = false
 
     // Time at which this sender should finish if the response stream is not 
finished by then.
-    val deadlineTimeMillis = if (!executeHolder.reattachable) {
+    deadlineTimeMillis = if (!executeHolder.reattachable) {
       Long.MaxValue
     } else {
       val confSize =
@@ -232,8 +241,8 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
Message](
           assert(finished == false)
         } else {
           // If it wasn't sent, time deadline must have been reached before 
stream became available,
-          // will exit in the enxt loop iterattion.
-          assert(deadlineLimitReached)
+          // or it was interrupted. Will exit in the next loop iteration.
+          assert(deadlineLimitReached || interrupted)
         }
       } else if (streamFinished) {
         // Stream is finished and all responses have been sent
@@ -301,7 +310,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
Message](
         val sleepStart = System.nanoTime()
         var sleepEnd = 0L
         // Conditions for exiting the inner loop
-        // 1. was detached
+        // 1. was interrupted
         // 2. grpcCallObserver is ready to send more data
         // 3. time deadline is reached
         while (!interrupted &&
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
index d9db07fd228..df0fb3ac3a5 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
@@ -73,11 +73,16 @@ private[connect] class ExecuteResponseObserver[T <: 
Message](val executeHolder:
   /** The index of the last response produced by execution. */
   private var lastProducedIndex: Long = 0 // first response will have index 1
 
+  // For testing
+  private[connect] var releasedUntilIndex: Long = 0
+
   /**
    * Highest response index that was consumed. Keeps track of it to decide 
which responses needs
    * to be cached, and to assert that all responses are consumed.
+   *
+   * Visible for testing.
    */
-  private var highestConsumedIndex: Long = 0
+  private[connect] var highestConsumedIndex: Long = 0
 
   /**
    * Consumer that waits for available responses. There can be only one at a 
time, @see
@@ -284,6 +289,7 @@ private[connect] class ExecuteResponseObserver[T <: 
Message](val executeHolder:
       responses.remove(i)
       i -= 1
     }
+    releasedUntilIndex = index
   }
 
   /**
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index bce07133392..974c13b08e3 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -183,6 +183,16 @@ private[connect] class ExecuteHolder(
     }
   }
 
+  // For testing: passes deadlineMs to setDeadline of every sender in grpcResponseSenders.
+  private[connect] def setGrpcResponseSendersDeadline(deadlineMs: Long) = 
synchronized {
+    grpcResponseSenders.foreach(_.setDeadline(deadlineMs))
+  }
+
+  // For testing: calls interrupt() on every sender in grpcResponseSenders.
+  private[connect] def interruptGrpcResponseSenders() = synchronized {
+    grpcResponseSenders.foreach(_.interrupt())
+  }
+
   /**
    * For a short period in ExecutePlan after creation and until 
runGrpcResponseSender is called,
    * there is no attached response sender, but yet we start with 
lastAttachedRpcTime = None, so we
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index ce1f6c93f6c..21f59bdd68e 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -71,15 +71,14 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
       // The latter is to prevent double execution when a client retries 
execution, thinking it
       // never reached the server, but in fact it did, and already got removed 
as abandoned.
       if (executions.get(executeHolder.key).isDefined) {
-        if (getAbandonedTombstone(executeHolder.key).isDefined) {
-          throw new SparkSQLException(
-            errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
-            messageParameters = Map("handle" -> executeHolder.operationId))
-        } else {
-          throw new SparkSQLException(
-            errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
-            messageParameters = Map("handle" -> executeHolder.operationId))
-        }
+        throw new SparkSQLException(
+          errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
+          messageParameters = Map("handle" -> executeHolder.operationId))
+      }
+      if (getAbandonedTombstone(executeHolder.key).isDefined) {
+        throw new SparkSQLException(
+          errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
+          messageParameters = Map("handle" -> executeHolder.operationId))
       }
       sessionHolder.addExecuteHolder(executeHolder)
       executions.put(executeHolder.key, executeHolder)
@@ -141,12 +140,17 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
     abandonedTombstones.asMap.asScala.values.toBuffer.toSeq
   }
 
-  private[service] def shutdown(): Unit = executionsLock.synchronized {
+  private[connect] def shutdown(): Unit = executionsLock.synchronized {
     scheduledExecutor.foreach { executor =>
       executor.shutdown()
       executor.awaitTermination(1, TimeUnit.MINUTES)
     }
     scheduledExecutor = None
+    executions.clear()
+    abandonedTombstones.invalidateAll()
+    if (!lastExecutionTime.isDefined) {
+      lastExecutionTime = Some(System.currentTimeMillis())
+    }
   }
 
   /**
@@ -188,7 +192,7 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
       executions.values.foreach { executeHolder =>
         executeHolder.lastAttachedRpcTime match {
           case Some(detached) =>
-            if (detached + timeout < nowMs) {
+            if (detached + timeout <= nowMs) {
               toRemove += executeHolder
             }
           case _ => // execution is active
@@ -206,4 +210,18 @@ private[connect] class SparkConnectExecutionManager() 
extends Logging {
     }
     logInfo("Finished periodic run of SparkConnectExecutionManager 
maintenance.")
   }
+
+  // For testing: sets deadlineMs on the response senders of every registered execution.
+  private[connect] def setAllRPCsDeadline(deadlineMs: Long) = 
executionsLock.synchronized {
+    executions.values.foreach(_.setGrpcResponseSendersDeadline(deadlineMs))
+  }
+
+  // For testing: interrupts the response senders of every registered execution.
+  private[connect] def interruptAllRPCs() = executionsLock.synchronized {
+    executions.values.foreach(_.interruptGrpcResponseSenders())
+  }
+
+  private[connect] def listExecuteHolders = executionsLock.synchronized { // snapshot copy
+    executions.values.toBuffer.toSeq
+  }
 }
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
new file mode 100644
index 00000000000..488858d33ea
--- /dev/null
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect
+
+import java.util.UUID
+
+import org.scalatest.concurrent.{Eventually, TimeLimits}
+import org.scalatest.time.Span
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.client.{CloseableIterator, 
CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, 
GrpcRetryHandler, SparkConnectClient, WrappedCloseableIterator}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.dsl.MockRemoteSession
+import org.apache.spark.sql.connect.dsl.plans._
+import org.apache.spark.sql.connect.service.{ExecuteHolder, 
SparkConnectService}
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Base class and utilities for a test suite that starts and tests the real 
SparkConnectService
+ * with a real SparkConnectClient, communicating over RPC, but both in-process.
+ */
+class SparkConnectServerTest extends SharedSparkSession {
+
+  // Server port: random offset in [0, 1000) from the default Connect gRPC binding port.
+  val serverPort: Int =
+    ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
+  // Timeout used by the assertEventually* helpers of this class.
+  val eventuallyTimeout = 30.seconds
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Other suites using mocks leave a mess in the global executionManager,
+    // shut it down so that it's cleared before starting server.
+    SparkConnectService.executionManager.shutdown()
+    // Start the real service.
+    withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key, 
serverPort.toString)) {
+      SparkConnectService.start(spark.sparkContext)
+    }
+    // register udf directly on the server, we're not testing client UDFs 
here...
+    val serverSession =
+      SparkConnectService.getOrCreateIsolatedSession(defaultUserId, 
defaultSessionId).session
+    serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms 
}))
+  }
+
+  override def afterAll(): Unit = {
+    SparkConnectService.stop()
+    super.afterAll()
+  }
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    clearAllExecutions()
+  }
+
+  override def afterEach(): Unit = {
+    clearAllExecutions()
+    super.afterEach()
+  }
+
+  protected def clearAllExecutions(): Unit = {
+    SparkConnectService.executionManager.listExecuteHolders.foreach(_.close())
+    SparkConnectService.executionManager.periodicMaintenance(0)
+    assertNoActiveExecutions()
+  }
+
+  protected val defaultSessionId = UUID.randomUUID.toString()
+  protected val defaultUserId = UUID.randomUUID.toString()
+
+  // We don't have the real SparkSession/Dataset api available,
+  // so use mock for generating simple query plans.
+  protected val dsl = new MockRemoteSession()
+
+  protected val userContext = proto.UserContext
+    .newBuilder()
+    .setUserId(defaultUserId)
+    .build()
+
+  protected def buildExecutePlanRequest(
+      plan: proto.Plan,
+      sessionId: String = defaultSessionId,
+      operationId: String = UUID.randomUUID.toString) = {
+    proto.ExecutePlanRequest
+      .newBuilder()
+      .setUserContext(userContext)
+      .setSessionId(sessionId)
+      .setOperationId(operationId)
+      .setPlan(plan)
+      .addRequestOptions(
+        proto.ExecutePlanRequest.RequestOption
+          .newBuilder()
+          
.setReattachOptions(proto.ReattachOptions.newBuilder().setReattachable(true).build())
+          .build())
+      .build()
+  }
+
+  protected def buildReattachExecuteRequest(operationId: String, responseId: 
Option[String]) = {
+    val req = proto.ReattachExecuteRequest
+      .newBuilder()
+      .setUserContext(userContext)
+      .setSessionId(defaultSessionId)
+      .setOperationId(operationId)
+
+    if (responseId.isDefined) {
+      req.setLastResponseId(responseId.get)
+    }
+
+    req.build()
+  }
+
+  protected def buildPlan(query: String) = {
+    proto.Plan.newBuilder().setRoot(dsl.sql(query)).build()
+  }
+
+  protected def getReattachableIterator(
+      stubIterator: CloseableIterator[proto.ExecutePlanResponse]) = {
+    // This depends on the wrapping in 
CustomSparkConnectBlockingStub.executePlanReattachable:
+    // GrpcExceptionConverter.convertIterator
+    stubIterator
+      .asInstanceOf[WrappedCloseableIterator[proto.ExecutePlanResponse]]
+      // ExecutePlanResponseReattachableIterator
+      .innerIterator
+      .asInstanceOf[ExecutePlanResponseReattachableIterator]
+  }
+
+  protected def assertNoActiveRpcs(): Unit = {
+    SparkConnectService.executionManager.listActiveExecutions match {
+      case Left(_) => // nothing running, good
+      case Right(executions) =>
+        // All RPCs detached: every remaining execution has lastAttachedRpcTime set.
+        assert(
+          executions.forall(_.lastAttachedRpcTime.isDefined),
+          s"Expected no RPCs, but got $executions")
+    }
+  }
+
+  protected def assertEventuallyNoActiveRpcs(): Unit = {
+    Eventually.eventually(timeout(eventuallyTimeout)) {
+      assertNoActiveRpcs()
+    }
+  }
+
+  protected def assertNoActiveExecutions(): Unit = {
+    SparkConnectService.executionManager.listActiveExecutions match {
+      case Left(_) => // cleaned up
+      case Right(executions) => fail(s"Expected empty, but got $executions")
+    }
+  }
+
+  protected def assertEventuallyNoActiveExecutions(): Unit = {
+    Eventually.eventually(timeout(eventuallyTimeout)) {
+      assertNoActiveExecutions()
+    }
+  }
+
+  protected def assertExecutionReleased(operationId: String): Unit = {
+    SparkConnectService.executionManager.listActiveExecutions match {
+      case Left(_) => // cleaned up
+      case Right(executions) => assert(!executions.exists(_.operationId == 
operationId))
+    }
+  }
+
+  protected def assertEventuallyExecutionReleased(operationId: String): Unit = 
{
+    Eventually.eventually(timeout(eventuallyTimeout)) {
+      assertExecutionReleased(operationId)
+    }
+  }
+
+  // Get the ExecuteHolder, asserting that exactly one execution is registered.
+  protected def getExecutionHolder: ExecuteHolder = {
+    val executions = SparkConnectService.executionManager.listExecuteHolders
+    assert(executions.length == 1)
+    executions.head
+  }
+  // Runs f with a reattachable-execute client for the default session; shuts it down after.
+  protected def withClient(f: SparkConnectClient => Unit): Unit = {
+    val client = SparkConnectClient
+      .builder()
+      .port(serverPort)
+      .sessionId(defaultSessionId)
+      .userId(defaultUserId)
+      .enableReattachableExecute()
+      .build()
+    try f(client)
+    finally {
+      client.shutdown()
+    }
+  }
+
+  protected def withRawBlockingStub(
+      f: proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub => 
Unit): Unit = {
+    val conf = SparkConnectClient.Configuration(port = serverPort)
+    val channel = conf.createChannel()
+    val bstub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
+    try f(bstub)
+    finally {
+      channel.shutdownNow()
+    }
+  }
+
+  protected def withCustomBlockingStub(
+      retryPolicy: GrpcRetryHandler.RetryPolicy = 
GrpcRetryHandler.RetryPolicy())(
+      f: CustomSparkConnectBlockingStub => Unit): Unit = {
+    val conf = SparkConnectClient.Configuration(port = serverPort)
+    val channel = conf.createChannel()
+    val bstub = new CustomSparkConnectBlockingStub(channel, retryPolicy)
+    try f(bstub)
+    finally {
+      channel.shutdownNow()
+    }
+  }
+  // Runs the plan to completion within queryTimeout, checking ResultComplete and release.
+  protected def runQuery(plan: proto.Plan, queryTimeout: Span, iterSleep: 
Long): Unit = {
+    withClient { client =>
+      TimeLimits.failAfter(queryTimeout) {
+        val iter = client.execute(plan)
+        var operationId: Option[String] = None
+        var r: proto.ExecutePlanResponse = null
+        val reattachableIter = getReattachableIterator(iter)
+        while (iter.hasNext) {
+          r = iter.next()
+          operationId match {
+            case None => operationId = Some(r.getOperationId)
+            case Some(id) => assert(r.getOperationId == id)
+          }
+          if (iterSleep > 0) {
+            Thread.sleep(iterSleep)
+          }
+        }
+        // Check that last response had ResultComplete indicator
+        assert(r != null)
+        assert(r.hasResultComplete)
+        // ... that client sent ReleaseExecute based on it
+        assert(reattachableIter.resultComplete)
+        // ... and that the server released the execution.
+        assert(operationId.isDefined)
+        assertEventuallyExecutionReleased(operationId.get)
+      }
+    }
+  }
+
+  protected def runQuery(query: String, queryTimeout: Span, iterSleep: Long = 
0): Unit = {
+    val plan = buildPlan(query)
+    runQuery(plan, queryTimeout, iterSleep)
+  }
+}
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
new file mode 100644
index 00000000000..169b15582b6
--- /dev/null
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.execution
+
+import java.util.UUID
+
+import io.grpc.StatusRuntimeException
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.connect.SparkConnectServerTest
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.service.SparkConnectService
+
/**
 * Tests of reattachable execution against a real Spark Connect server running in the same
 * process.
 *
 * Tests use either the real client (`withClient`), which retries and reattaches automatically,
 * or a raw gRPC stub (`withRawBlockingStub`), which does neither. The raw stub lets tests observe
 * the server-side errors directly.
 */
class ReattachableExecuteSuite extends SparkConnectServerTest {

  // Tests assume that this query will result in at least a couple ExecutePlanResponses on the
  // stream. If this is no longer the case because of changes in how much is returned in a single
  // ExecutePlanResponse, it may need to be adjusted.
  val MEDIUM_RESULTS_QUERY = "select * from range(1000000)"

  test("reattach after initial RPC ends") {
    withClient { client =>
      val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
      val reattachableIter = getReattachableIterator(iter)
      val initialInnerIter = reattachableIter.innerIterator

      // open the iterator
      iter.next()
      // expire all RPCs on server
      SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1)
      assertEventuallyNoActiveRpcs()
      // The iterator should reattach, but not necessarily on the first next(), because some
      // messages may already have been buffered on the client side.
      while (iter.hasNext && (reattachableIter.innerIterator eq initialInnerIter)) {
        iter.next()
      }
      // reattach changed the inner iter
      assert(reattachableIter.innerIterator ne initialInnerIter)
    }
  }

  test("raw interrupted RPC results in INVALID_CURSOR.DISCONNECTED error") {
    withRawBlockingStub { stub =>
      val iter = stub.executePlan(buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY)))
      iter.next() // open the iterator
      // interrupt all RPCs on server
      SparkConnectService.executionManager.interruptAllRPCs()
      assertEventuallyNoActiveRpcs()
      val e = intercept[StatusRuntimeException] {
        while (iter.hasNext) iter.next()
      }
      assert(e.getMessage.contains("INVALID_CURSOR.DISCONNECTED"))
    }
  }

  test("raw new RPC interrupts previous RPC with INVALID_CURSOR.DISCONNECTED error") {
    // Raw stub does not have retries, auto reattach etc.
    withRawBlockingStub { stub =>
      val operationId = UUID.randomUUID().toString
      val iter = stub.executePlan(
        buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId))
      iter.next() // open the iterator

      // send reattach
      val iter2 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
      iter2.next() // open the iterator

      // should result in INVALID_CURSOR.DISCONNECTED error on the original iterator
      val e = intercept[StatusRuntimeException] {
        while (iter.hasNext) iter.next()
      }
      assert(e.getMessage.contains("INVALID_CURSOR.DISCONNECTED"))

      // send another reattach
      val iter3 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
      assert(iter3.hasNext)
      iter3.next() // open the iterator

      // should result in INVALID_CURSOR.DISCONNECTED error on the previous reattach iterator
      val e2 = intercept[StatusRuntimeException] {
        while (iter2.hasNext) iter2.next()
      }
      assert(e2.getMessage.contains("INVALID_CURSOR.DISCONNECTED"))
    }
  }

  test("client INVALID_CURSOR.DISCONNECTED error is retried when rpc sender gets interrupted") {
    withClient { client =>
      val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
      val reattachableIter = getReattachableIterator(iter)
      val initialInnerIter = reattachableIter.innerIterator

      // open the iterator
      iter.next()

      // interrupt all RPCs on server
      SparkConnectService.executionManager.interruptAllRPCs()
      assertEventuallyNoActiveRpcs()

      // Nevertheless, the original iterator will handle the INVALID_CURSOR.DISCONNECTED error
      iter.next()
      // iterator changed because it had to reconnect
      assert(reattachableIter.innerIterator ne initialInnerIter)
    }
  }

  test("client INVALID_CURSOR.DISCONNECTED error is retried when other RPC preempts this one") {
    withClient { client =>
      val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
      val reattachableIter = getReattachableIterator(iter)
      val initialInnerIter = reattachableIter.innerIterator
      val operationId = reattachableIter.operationId

      // open the iterator
      val response = iter.next()

      // Send another Reattach request, it should preempt this request with an
      // INVALID_CURSOR.DISCONNECTED error.
      withRawBlockingStub { stub =>
        val reattachIter = stub.reattachExecute(
          buildReattachExecuteRequest(operationId, Some(response.getResponseId)))
        assert(reattachIter.hasNext)
        reattachIter.next()

        // Nevertheless, the original iterator will handle the INVALID_CURSOR.DISCONNECTED error
        iter.next()
        // iterator changed because it had to reconnect
        assert(reattachableIter.innerIterator ne initialInnerIter)
      }
    }
  }

  test("abandoned query gets INVALID_HANDLE.OPERATION_ABANDONED error") {
    withClient { client =>
      // Plan used further below to try to resubmit under the abandoned operationId.
      val plan = buildPlan("select * from range(100000)")
      val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
      val operationId = getReattachableIterator(iter).operationId
      // open the iterator
      iter.next()
      // disconnect and remove on server
      SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1)
      assertEventuallyNoActiveRpcs()
      SparkConnectService.executionManager.periodicMaintenance(0)
      assertNoActiveExecutions()
      // check that it throws abandoned error
      val e = intercept[SparkException] {
        while (iter.hasNext) iter.next()
      }
      assert(e.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
      // check that afterwards, new operation can't be created with the same operationId.
      withCustomBlockingStub() { stub =>
        val executePlanReq = buildExecutePlanRequest(plan, operationId = operationId)

        val iterNonReattachable = stub.executePlan(executePlanReq)
        val eNonReattachable = intercept[SparkException] {
          iterNonReattachable.hasNext
        }
        assert(eNonReattachable.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))

        val iterReattachable = stub.executePlanReattachable(executePlanReq)
        val eReattachable = intercept[SparkException] {
          iterReattachable.hasNext
        }
        assert(eReattachable.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
      }
    }
  }

  test("client releases responses directly after consuming them") {
    withClient { client =>
      val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
      val reattachableIter = getReattachableIterator(iter)
      val initialInnerIter = reattachableIter.innerIterator
      val operationId = reattachableIter.operationId

      assert(iter.hasNext) // open iterator
      val execution = getExecutionHolder
      assert(execution.responseObserver.releasedUntilIndex == 0)

      // get two responses, check on the server that ReleaseExecute releases them afterwards
      val response1 = iter.next()
      Eventually.eventually(timeout(eventuallyTimeout)) {
        assert(execution.responseObserver.releasedUntilIndex == 1)
      }

      val response2 = iter.next()
      Eventually.eventually(timeout(eventuallyTimeout)) {
        assert(execution.responseObserver.releasedUntilIndex == 2)
      }

      withRawBlockingStub { stub =>
        // Reattach after response1 should fail with INVALID_CURSOR.POSITION_NOT_AVAILABLE
        val reattach1 = stub.reattachExecute(
          buildReattachExecuteRequest(operationId, Some(response1.getResponseId)))
        val e = intercept[StatusRuntimeException] {
          reattach1.hasNext()
        }
        assert(e.getMessage.contains("INVALID_CURSOR.POSITION_NOT_AVAILABLE"))

        // Reattach after response2 should work
        val reattach2 = stub.reattachExecute(
          buildReattachExecuteRequest(operationId, Some(response2.getResponseId)))
        val response3 = reattach2.next()
        val response4 = reattach2.next()
        val response5 = reattach2.next()

        // The original client iterator will handle the INVALID_CURSOR.DISCONNECTED error
        // and reconnect. Since the raw iterator was not releasing responses, the client iterator
        // should be able to continue where it left off (the server shouldn't have released yet).
        assert(execution.responseObserver.releasedUntilIndex == 2)
        assert(iter.hasNext)

        val r3 = iter.next()
        assert(r3.getResponseId == response3.getResponseId)
        val r4 = iter.next()
        assert(r4.getResponseId == response4.getResponseId)
        val r5 = iter.next()
        assert(r5.getResponseId == response5.getResponseId)
        // inner iterator changed because it had to reconnect
        assert(reattachableIter.innerIterator ne initialInnerIter)
      }
    }
  }

  test("server releases responses automatically when client moves ahead") {
    withRawBlockingStub { stub =>
      val operationId = UUID.randomUUID().toString
      val iter = stub.executePlan(
        buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId))
      // Response id of the last response consumed from `iter`, if any.
      var lastSeenResponse: Option[String] = None

      iter.hasNext // open iterator
      val execution = getExecutionHolder

      // after consuming enough from the iterator, server should automatically start releasing
      var lastSeenIndex = 0
      while (iter.hasNext && execution.responseObserver.releasedUntilIndex == 0) {
        val r = iter.next()
        lastSeenResponse = Some(r.getResponseId())
        lastSeenIndex += 1
      }
      assert(iter.hasNext)
      assert(execution.responseObserver.releasedUntilIndex > 0)

      // Reattach from the beginning is not available.
      val reattach = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
      val e = intercept[StatusRuntimeException] {
        reattach.hasNext()
      }
      assert(e.getMessage.contains("INVALID_CURSOR.POSITION_NOT_AVAILABLE"))

      // Original iterator got disconnected by the reattach and gets INVALID_CURSOR.DISCONNECTED
      val e2 = intercept[StatusRuntimeException] {
        while (iter.hasNext) iter.next()
      }
      assert(e2.getMessage.contains("INVALID_CURSOR.DISCONNECTED"))

      Eventually.eventually(timeout(eventuallyTimeout)) {
        // Even though we didn't consume more from the iterator, the server thinks that
        // it sent more, because GRPC stream onNext() can push into internal GRPC buffer without
        // client picking it up.
        assert(execution.responseObserver.highestConsumedIndex > lastSeenIndex)
      }
      // but CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE is big enough that the last
      // response we've seen is still in range
      assert(execution.responseObserver.releasedUntilIndex < lastSeenIndex)

      // and a new reattach can continue right after the last response we have seen.
      val reattach2 =
        stub.reattachExecute(buildReattachExecuteRequest(operationId, lastSeenResponse))
      assert(reattach2.hasNext)
      while (reattach2.hasNext) reattach2.next()
    }
  }

  // A few integration tests with large results.
  // They should run significantly faster than the LARGE_QUERY_TIMEOUT
  // - big query (4 seconds, 871 milliseconds)
  // - big query and slow client (7 seconds, 288 milliseconds)
  // - big query with frequent reattach (1 second, 527 milliseconds)
  // - big query with frequent reattach and slow client (7 seconds, 365 milliseconds)
  // - long sleeping query (10 seconds, 805 milliseconds)

  // Intentionally smaller than CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION, so that
  // a reattach triggered by the stream deadline cannot "unstick" a query that actually got stuck.
  val LARGE_QUERY_TIMEOUT = 100.seconds

  // "select id, cast(id as string) c1, ..., cast(id as string) c20 from range(1000000)"
  val LARGE_RESULTS_QUERY = (1 to 20)
    .map(i => s"cast(id as string) c$i")
    .mkString("select id, ", ", ", " from range(1000000)")

  test("big query") {
    // regular query with large results
    runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT)
    // Check that execution is released on the server.
    assertEventuallyNoActiveExecutions()
  }

  test("big query and slow client") {
    // regular query with large results, but client is slow so sender will need to control flow
    runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT, iterSleep = 50)
    // Check that execution is released on the server.
    assertEventuallyNoActiveExecutions()
  }

  test("big query with frequent reattach") {
    // will reattach every 100kB
    withSparkEnvConfs((Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE.key, "100k")) {
      runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT)
      // Check that execution is released on the server.
      assertEventuallyNoActiveExecutions()
    }
  }

  test("big query with frequent reattach and slow client") {
    // will reattach every 100kB, and in addition the client is slow,
    // so sender will need to control flow
    withSparkEnvConfs((Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE.key, "100k")) {
      runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT, iterSleep = 50)
      // Check that execution is released on the server.
      assertEventuallyNoActiveExecutions()
    }
  }

  test("long sleeping query") {
    // query will be sleeping and not returning results, while having multiple reattach
    withSparkEnvConfs(
      (Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION.key, "1s")) {
      runQuery("select sleep(10000) as s", 30.seconds)
      // Check that execution is released on the server.
      assertEventuallyNoActiveExecutions()
    }
  }
}
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala 
b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index f5819b95087..1163088c82a 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -302,6 +302,30 @@ abstract class SparkFunSuite
     }
   }
 
+  /**
+   * Sets all configurations specified in `pairs` in SparkEnv SparkConf, calls 
`f`, and then
+   * restores all configurations.
+   */
+  protected def withSparkEnvConfs(pairs: (String, String)*)(f: => Unit): Unit 
= {
+    val conf = SparkEnv.get.conf
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.getOption(key).isDefined) {
+        Some(conf.get(key))
+      } else {
+        None
+      }
+    }
+    pairs.foreach { kv => conf.set(kv._1, kv._2) }
+    try f
+    finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.set(key, value)
+        case (key, None) => conf.remove(key)
+      }
+    }
+  }
+
   /**
    * Checks an exception with an error class against expected results.
    * @param exception     The exception to check


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to