This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new af8c0b999be [SPARK-44872][CONNECT] Server testing infra and ReattachableExecuteSuite
af8c0b999be is described below

commit af8c0b999be746b661efe2439ac015a0c7d12c00
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Tue Sep 12 16:48:26 2023 +0200

    [SPARK-44872][CONNECT] Server testing infra and ReattachableExecuteSuite

    ### What changes were proposed in this pull request?

    Add `SparkConnectServerTest` with infra to test a real server with a real client in the same process, communicating over RPC. Add `ReattachableExecuteSuite` with tests for reattachable execute.

    Two bugs were found by the tests:
    * Fix a bug in `SparkConnectExecutionManager.createExecuteHolder` when attempting to resubmit an operation that was deemed abandoned. This bug is benign for reattachable execute, because a reattachable client would first send a ReattachExecute, which is handled correctly by SparkConnectReattachExecuteHandler. For non-reattachable execute (disabled or an old client), this is also a very unlikely scenario, because the retrying mechanism should be able to resubmit before the query is decl [...]
    * In `ExecuteGrpcResponseSender` there was an assertion that assumed that if `sendResponse` did not send, it was because the deadline was reached. But it can also be because of an interrupt. In testing, this would have caused an interrupt to return an assertion error instead of CURSOR_DISCONNECTED. Assertions are not enabled outside of testing, so this was not a problem there.

    ### Why are the changes needed?

    Testing of reattachable execute.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Tests added.

    Closes #42560 from juliuszsompolski/sc-reattachable-tests.
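    The second fix can be illustrated with a minimal, self-contained sketch (hypothetical names; the real logic lives inside `ExecuteGrpcResponseSender`): when a send attempt does not send anything, both a reached deadline and an interrupt are legitimate causes, and the assertion must accept either.

```scala
// A runnable model of the corrected assertion; `checkAfterSendAttempt` and its
// parameters are illustrative stand-ins for state inside ExecuteGrpcResponseSender.
object SendLoopConditionSketch {
  def checkAfterSendAttempt(
      sent: Boolean,
      deadlineLimitReached: Boolean,
      interrupted: Boolean): Unit = {
    if (!sent) {
      // Before the fix this was assert(deadlineLimitReached), so an interrupt
      // tripped the assertion instead of surfacing CURSOR_DISCONNECTED.
      assert(deadlineLimitReached || interrupted)
    }
  }

  def main(args: Array[String]): Unit = {
    // Interrupted but deadline not reached: passes after the fix, failed before it.
    checkAfterSendAttempt(sent = false, deadlineLimitReached = false, interrupted = true)
  }
}
```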
Authored-by: Juliusz Sompolski <ju...@databricks.com>
Signed-off-by: Herman van Hovell <her...@databricks.com>
(cherry picked from commit 4b96add471d292ed5c63ccc625489ff78cfb9b25)
Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../sql/connect/client/CloseableIterator.scala      |  22 +-
 .../client/CustomSparkConnectBlockingStub.scala     |   2 +-
 .../ExecutePlanResponseReattachableIterator.scala   |  18 +-
 .../connect/client/GrpcExceptionConverter.scala     |   5 +-
 .../sql/connect/client/GrpcRetryHandler.scala       |   4 +-
 .../execution/ExecuteGrpcResponseSender.scala       |  17 +-
 .../execution/ExecuteResponseObserver.scala         |   8 +-
 .../spark/sql/connect/service/ExecuteHolder.scala   |  10 +
 .../service/SparkConnectExecutionManager.scala      |  40 ++-
 .../spark/sql/connect/SparkConnectServerTest.scala  | 261 +++++++++++++++
 .../execution/ReattachableExecuteSuite.scala        | 352 +++++++++++++++++++++
 .../scala/org/apache/spark/SparkFunSuite.scala      |  24 ++
 12 files changed, 735 insertions(+), 28 deletions(-)

diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
index 891e50ed6e7..d3fc9963edc 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
@@ -27,6 +27,20 @@ private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable {
   }
 }
 
+private[sql] abstract class WrappedCloseableIterator[E] extends CloseableIterator[E] {
+
+  def innerIterator: Iterator[E]
+
+  override def next(): E = innerIterator.next()
+
+  override def hasNext(): Boolean = innerIterator.hasNext
+
+  override def close(): Unit = innerIterator match {
+    case it: CloseableIterator[E] => it.close()
+    case _ => // nothing
+  }
+}
+
 private[sql] object CloseableIterator {
 
   /**
@@ -35,12 +49,8 @@ private[sql] object CloseableIterator {
   def apply[T](iterator: Iterator[T]): CloseableIterator[T] = iterator match {
     case closeable: CloseableIterator[T] => closeable
     case _ =>
-      new CloseableIterator[T] {
-        override def next(): T = iterator.next()
-
-        override def hasNext(): Boolean = iterator.hasNext
-
-        override def close() = { /* empty */ }
+      new WrappedCloseableIterator[T] {
+        override def innerIterator = iterator
       }
   }
 }
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index 73ff01e223f..80edcfa8be1 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -22,7 +22,7 @@ import io.grpc.ManagedChannel
 
 import org.apache.spark.connect.proto._
 
-private[client] class CustomSparkConnectBlockingStub(
+private[connect] class CustomSparkConnectBlockingStub(
     channel: ManagedChannel,
     retryPolicy: GrpcRetryHandler.RetryPolicy) {
 
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
index 9bf7de33da8..57a629264be 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.client
 
 import java.util.UUID
 
+import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
 import io.grpc.{ManagedChannel, StatusRuntimeException}
@@ -50,7 +51,7 @@ class ExecutePlanResponseReattachableIterator(
     request: proto.ExecutePlanRequest,
     channel: ManagedChannel,
     retryPolicy: GrpcRetryHandler.RetryPolicy)
-    extends CloseableIterator[proto.ExecutePlanResponse]
+    extends WrappedCloseableIterator[proto.ExecutePlanResponse]
     with Logging {
 
   val operationId = if (request.hasOperationId) {
@@ -86,14 +87,25 @@ class ExecutePlanResponseReattachableIterator(
   // True after ResultComplete message was seen in the stream.
   // Server will always send this message at the end of the stream, if the underlying iterator
   // finishes without producing one, another iterator needs to be reattached.
-  private var resultComplete: Boolean = false
+  // Visible for testing.
+  private[connect] var resultComplete: Boolean = false
 
   // Initial iterator comes from ExecutePlan request.
   // Note: This is not retried, because no error would ever be thrown here, and GRPC will only
   // throw error on first iter.hasNext() or iter.next()
-  private var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] =
+  // Visible for testing.
+  private[connect] var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] =
     Some(rawBlockingStub.executePlan(initialRequest))
 
+  override def innerIterator: Iterator[proto.ExecutePlanResponse] = iter match {
+    case Some(it) => it.asScala
+    case None =>
+      // The iterator is only unset for short moments while retry exception is thrown.
+      // It should only happen in the middle of internal processing. Since this iterator is not
+      // thread safe, no-one should be accessing it at this moment.
+      throw new IllegalStateException("innerIterator unset")
+  }
+
   override def next(): proto.ExecutePlanResponse = synchronized {
     // hasNext will trigger reattach in case the stream completed without resultComplete
     if (!hasNext()) {
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index c430485bd41..fe9f6dc2b4a 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -43,7 +43,10 @@ private[client] object GrpcExceptionConverter extends JsonUtils {
   }
 
   def convertIterator[T](iter: CloseableIterator[T]): CloseableIterator[T] = {
-    new CloseableIterator[T] {
+    new WrappedCloseableIterator[T] {
+
+      override def innerIterator: Iterator[T] = iter
+
       override def hasNext: Boolean = {
         convert {
           iter.hasNext
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
index 8791530607c..3c0b750fd46 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
@@ -48,11 +48,13 @@ private[sql] class GrpcRetryHandler(
    *   The type of the response.
    */
  class RetryIterator[T, U](request: T, call: T => CloseableIterator[U])
-      extends CloseableIterator[U] {
+      extends WrappedCloseableIterator[U] {
 
    private var opened = false // we only retry if it fails on first call when using the iterator
    private var iter = call(request)
 
+    override def innerIterator: Iterator[U] = iter
+
    private def retryIter[V](f: Iterator[U] => V) = {
      if (!opened) {
        opened = true
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
index 6b8fcde1156..c3c33a85d65 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
@@ -47,6 +47,9 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
 
   private var interrupted = false
 
+  // Time at which this sender should finish if the response stream is not finished by then.
+  private var deadlineTimeMillis = Long.MaxValue
+
   // Signal to wake up when grpcCallObserver.isReady()
   private val grpcCallObserverReadySignal = new Object
 
@@ -65,6 +68,12 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
     executionObserver.notifyAll()
   }
 
+  // For testing
+  private[connect] def setDeadline(deadlineMs: Long) = executionObserver.synchronized {
+    deadlineTimeMillis = deadlineMs
+    executionObserver.notifyAll()
+  }
+
   def run(lastConsumedStreamIndex: Long): Unit = {
     if (executeHolder.reattachable) {
       // In reattachable execution we use setOnReadyHandler and grpcCallObserver.isReady to control
@@ -150,7 +159,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
     var finished = false
 
     // Time at which this sender should finish if the response stream is not finished by then.
-    val deadlineTimeMillis = if (!executeHolder.reattachable) {
+    deadlineTimeMillis = if (!executeHolder.reattachable) {
       Long.MaxValue
     } else {
       val confSize =
@@ -232,8 +241,8 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
           assert(finished == false)
         } else {
           // If it wasn't sent, time deadline must have been reached before stream became available,
-          // will exit in the enxt loop iterattion.
-          assert(deadlineLimitReached)
+          // or it was interrupted. Will exit in the next loop iteration.
+          assert(deadlineLimitReached || interrupted)
         }
       } else if (streamFinished) {
         // Stream is finished and all responses have been sent
@@ -301,7 +310,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
       val sleepStart = System.nanoTime()
       var sleepEnd = 0L
       // Conditions for exiting the inner loop
-      // 1. was detached
+      // 1. was interrupted
       // 2. grpcCallObserver is ready to send more data
       // 3. time deadline is reached
       while (!interrupted &&
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
index d9db07fd228..df0fb3ac3a5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
@@ -73,11 +73,16 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
   /** The index of the last response produced by execution. */
   private var lastProducedIndex: Long = 0 // first response will have index 1
 
+  // For testing
+  private[connect] var releasedUntilIndex: Long = 0
+
   /**
    * Highest response index that was consumed. Keeps track of it to decide which responses needs
    * to be cached, and to assert that all responses are consumed.
+   *
+   * Visible for testing.
    */
-  private var highestConsumedIndex: Long = 0
+  private[connect] var highestConsumedIndex: Long = 0
 
   /**
    * Consumer that waits for available responses. There can be only one at a time, @see
@@ -284,6 +289,7 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
       responses.remove(i)
       i -= 1
     }
+    releasedUntilIndex = index
   }
 
   /**
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index bce07133392..974c13b08e3 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -183,6 +183,16 @@ private[connect] class ExecuteHolder(
     }
   }
 
+  // For testing.
+  private[connect] def setGrpcResponseSendersDeadline(deadlineMs: Long) = synchronized {
+    grpcResponseSenders.foreach(_.setDeadline(deadlineMs))
+  }
+
+  // For testing
+  private[connect] def interruptGrpcResponseSenders() = synchronized {
+    grpcResponseSenders.foreach(_.interrupt())
+  }
+
   /**
    * For a short period in ExecutePlan after creation and until runGrpcResponseSender is called,
    * there is no attached response sender, but yet we start with lastAttachedRpcTime = None, so we
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index ce1f6c93f6c..21f59bdd68e 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -71,15 +71,14 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
     // The latter is to prevent double execution when a client retries execution, thinking it
     // never reached the server, but in fact it did, and already got removed as abandoned.
     if (executions.get(executeHolder.key).isDefined) {
-      if (getAbandonedTombstone(executeHolder.key).isDefined) {
-        throw new SparkSQLException(
-          errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
-          messageParameters = Map("handle" -> executeHolder.operationId))
-      } else {
-        throw new SparkSQLException(
-          errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
-          messageParameters = Map("handle" -> executeHolder.operationId))
-      }
+      throw new SparkSQLException(
+        errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
+        messageParameters = Map("handle" -> executeHolder.operationId))
+    }
+    if (getAbandonedTombstone(executeHolder.key).isDefined) {
+      throw new SparkSQLException(
+        errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
+        messageParameters = Map("handle" -> executeHolder.operationId))
     }
     sessionHolder.addExecuteHolder(executeHolder)
     executions.put(executeHolder.key, executeHolder)
@@ -141,12 +140,17 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
     abandonedTombstones.asMap.asScala.values.toBuffer.toSeq
   }
 
-  private[service] def shutdown(): Unit = executionsLock.synchronized {
+  private[connect] def shutdown(): Unit = executionsLock.synchronized {
     scheduledExecutor.foreach { executor =>
       executor.shutdown()
       executor.awaitTermination(1, TimeUnit.MINUTES)
     }
     scheduledExecutor = None
+    executions.clear()
+    abandonedTombstones.invalidateAll()
+    if (!lastExecutionTime.isDefined) {
+      lastExecutionTime = Some(System.currentTimeMillis())
+    }
   }
 
   /**
@@ -188,7 +192,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
       executions.values.foreach { executeHolder =>
         executeHolder.lastAttachedRpcTime match {
           case Some(detached) =>
-            if (detached + timeout < nowMs) {
+            if (detached + timeout <= nowMs) {
               toRemove += executeHolder
             }
           case _ => // execution is active
@@ -206,4 +210,18 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
     }
     logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.")
   }
+
+  // For testing.
+  private[connect] def setAllRPCsDeadline(deadlineMs: Long) = executionsLock.synchronized {
+    executions.values.foreach(_.setGrpcResponseSendersDeadline(deadlineMs))
+  }
+
+  // For testing.
+  private[connect] def interruptAllRPCs() = executionsLock.synchronized {
+    executions.values.foreach(_.interruptGrpcResponseSenders())
+  }
+
+  private[connect] def listExecuteHolders = executionsLock.synchronized {
+    executions.values.toBuffer.toSeq
+  }
 }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
new file mode 100644
index 00000000000..488858d33ea
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect
+
+import java.util.UUID
+
+import org.scalatest.concurrent.{Eventually, TimeLimits}
+import org.scalatest.time.Span
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, GrpcRetryHandler, SparkConnectClient, WrappedCloseableIterator}
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.dsl.MockRemoteSession
+import org.apache.spark.sql.connect.dsl.plans._
+import org.apache.spark.sql.connect.service.{ExecuteHolder, SparkConnectService}
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Base class and utilities for a test suite that starts and tests the real SparkConnectService
+ * with a real SparkConnectClient, communicating over RPC, but both in-process.
+ */
+class SparkConnectServerTest extends SharedSparkSession {
+
+  // Server port
+  val serverPort: Int =
+    ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
+
+  val eventuallyTimeout = 30.seconds
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Other suites using mocks leave a mess in the global executionManager,
+    // shut it down so that it's cleared before starting server.
+    SparkConnectService.executionManager.shutdown()
+    // Start the real service.
+    withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key, serverPort.toString)) {
+      SparkConnectService.start(spark.sparkContext)
+    }
+    // register udf directly on the server, we're not testing client UDFs here...
+    val serverSession =
+      SparkConnectService.getOrCreateIsolatedSession(defaultUserId, defaultSessionId).session
+    serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms }))
+  }
+
+  override def afterAll(): Unit = {
+    SparkConnectService.stop()
+    super.afterAll()
+  }
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    clearAllExecutions()
+  }
+
+  override def afterEach(): Unit = {
+    clearAllExecutions()
+    super.afterEach()
+  }
+
+  protected def clearAllExecutions(): Unit = {
+    SparkConnectService.executionManager.listExecuteHolders.foreach(_.close())
+    SparkConnectService.executionManager.periodicMaintenance(0)
+    assertNoActiveExecutions()
+  }
+
+  protected val defaultSessionId = UUID.randomUUID.toString()
+  protected val defaultUserId = UUID.randomUUID.toString()
+
+  // We don't have the real SparkSession/Dataset api available,
+  // so use mock for generating simple query plans.
+  protected val dsl = new MockRemoteSession()
+
+  protected val userContext = proto.UserContext
+    .newBuilder()
+    .setUserId(defaultUserId)
+    .build()
+
+  protected def buildExecutePlanRequest(
+      plan: proto.Plan,
+      sessionId: String = defaultSessionId,
+      operationId: String = UUID.randomUUID.toString) = {
+    proto.ExecutePlanRequest
+      .newBuilder()
+      .setUserContext(userContext)
+      .setSessionId(sessionId)
+      .setOperationId(operationId)
+      .setPlan(plan)
+      .addRequestOptions(
+        proto.ExecutePlanRequest.RequestOption
+          .newBuilder()
+          .setReattachOptions(proto.ReattachOptions.newBuilder().setReattachable(true).build())
+          .build())
+      .build()
+  }
+
+  protected def buildReattachExecuteRequest(operationId: String, responseId: Option[String]) = {
+    val req = proto.ReattachExecuteRequest
+      .newBuilder()
+      .setUserContext(userContext)
+      .setSessionId(defaultSessionId)
+      .setOperationId(operationId)
+
+    if (responseId.isDefined) {
+      req.setLastResponseId(responseId.get)
+    }
+
+    req.build()
+  }
+
+  protected def buildPlan(query: String) = {
+    proto.Plan.newBuilder().setRoot(dsl.sql(query)).build()
+  }
+
+  protected def getReattachableIterator(
+      stubIterator: CloseableIterator[proto.ExecutePlanResponse]) = {
+    // This depends on the wrapping in CustomSparkConnectBlockingStub.executePlanReattachable:
+    // GrpcExceptionConverter.convertIterator
+    stubIterator
+      .asInstanceOf[WrappedCloseableIterator[proto.ExecutePlanResponse]]
+      // ExecutePlanResponseReattachableIterator
+      .innerIterator
+      .asInstanceOf[ExecutePlanResponseReattachableIterator]
+  }
+
+  protected def assertNoActiveRpcs(): Unit = {
+    SparkConnectService.executionManager.listActiveExecutions match {
+      case Left(_) => // nothing running, good
+      case Right(executions) =>
+        // all rpc detached.
+        assert(
+          executions.forall(_.lastAttachedRpcTime.isDefined),
+          s"Expected no RPCs, but got $executions")
+    }
+  }
+
+  protected def assertEventuallyNoActiveRpcs(): Unit = {
+    Eventually.eventually(timeout(eventuallyTimeout)) {
+      assertNoActiveRpcs()
+    }
+  }
+
+  protected def assertNoActiveExecutions(): Unit = {
+    SparkConnectService.executionManager.listActiveExecutions match {
+      case Left(_) => // cleaned up
+      case Right(executions) => fail(s"Expected empty, but got $executions")
+    }
+  }
+
+  protected def assertEventuallyNoActiveExecutions(): Unit = {
+    Eventually.eventually(timeout(eventuallyTimeout)) {
+      assertNoActiveExecutions()
+    }
+  }
+
+  protected def assertExecutionReleased(operationId: String): Unit = {
+    SparkConnectService.executionManager.listActiveExecutions match {
+      case Left(_) => // cleaned up
+      case Right(executions) => assert(!executions.exists(_.operationId == operationId))
+    }
+  }
+
+  protected def assertEventuallyExecutionReleased(operationId: String): Unit = {
+    Eventually.eventually(timeout(eventuallyTimeout)) {
+      assertExecutionReleased(operationId)
+    }
+  }
+
+  // Get ExecutionHolder, assuming that only one execution is active
+  protected def getExecutionHolder: ExecuteHolder = {
+    val executions = SparkConnectService.executionManager.listExecuteHolders
+    assert(executions.length == 1)
+    executions.head
+  }
+
+  protected def withClient(f: SparkConnectClient => Unit): Unit = {
+    val client = SparkConnectClient
+      .builder()
+      .port(serverPort)
+      .sessionId(defaultSessionId)
+      .userId(defaultUserId)
+      .enableReattachableExecute()
+      .build()
+    try f(client)
+    finally {
+      client.shutdown()
+    }
+  }
+
+  protected def withRawBlockingStub(
+      f: proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub => Unit): Unit = {
+    val conf = SparkConnectClient.Configuration(port = serverPort)
+    val channel = conf.createChannel()
+    val bstub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
+    try f(bstub)
+    finally {
+      channel.shutdownNow()
+    }
+  }
+
+  protected def withCustomBlockingStub(
+      retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy())(
+      f: CustomSparkConnectBlockingStub => Unit): Unit = {
+    val conf = SparkConnectClient.Configuration(port = serverPort)
+    val channel = conf.createChannel()
+    val bstub = new CustomSparkConnectBlockingStub(channel, retryPolicy)
+    try f(bstub)
+    finally {
+      channel.shutdownNow()
+    }
+  }
+
+  protected def runQuery(plan: proto.Plan, queryTimeout: Span, iterSleep: Long): Unit = {
+    withClient { client =>
+      TimeLimits.failAfter(queryTimeout) {
+        val iter = client.execute(plan)
+        var operationId: Option[String] = None
+        var r: proto.ExecutePlanResponse = null
+        val reattachableIter = getReattachableIterator(iter)
+        while (iter.hasNext) {
+          r = iter.next()
+          operationId match {
+            case None => operationId = Some(r.getOperationId)
+            case Some(id) => assert(r.getOperationId == id)
+          }
+          if (iterSleep > 0) {
+            Thread.sleep(iterSleep)
+          }
+        }
+        // Check that last response had ResultComplete indicator
+        assert(r != null)
+        assert(r.hasResultComplete)
+        // ... that client sent ReleaseExecute based on it
+        assert(reattachableIter.resultComplete)
+        // ... and that the server released the execution.
+        assert(operationId.isDefined)
+        assertEventuallyExecutionReleased(operationId.get)
+      }
+    }
+  }
+
+  protected def runQuery(query: String, queryTimeout: Span, iterSleep: Long = 0): Unit = {
+    val plan = buildPlan(query)
+    runQuery(plan, queryTimeout, iterSleep)
+  }
+}
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
new file mode 100644
index 00000000000..169b15582b6
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.execution
+
+import java.util.UUID
+
+import io.grpc.StatusRuntimeException
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.connect.SparkConnectServerTest
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.service.SparkConnectService
+
+class ReattachableExecuteSuite extends SparkConnectServerTest {
+
+  // Tests assume that this query will result in at least a couple ExecutePlanResponses on the
+  // stream. If this is no longer the case because of changes in how much is returned in a single
+  // ExecutePlanResponse, it may need to be adjusted.
+  val MEDIUM_RESULTS_QUERY = "select * from range(1000000)"
+
+  test("reattach after initial RPC ends") {
+    withClient { client =>
+      val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
+      val reattachableIter = getReattachableIterator(iter)
+      val initialInnerIter = reattachableIter.innerIterator
+
+      // open the iterator
+      iter.next()
+      // expire all RPCs on server
+      SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1)
+      assertEventuallyNoActiveRpcs()
+      // iterator should reattach
+      // (but not necessarily at first next, as there might have been messages buffered client side)
+      while (iter.hasNext && (reattachableIter.innerIterator eq initialInnerIter)) {
+        iter.next()
+      }
+      assert(
+        reattachableIter.innerIterator ne initialInnerIter
+      ) // reattach changed the inner iter
+    }
+  }
+
+  test("raw interrupted RPC results in INVALID_CURSOR.DISCONNECTED error") {
+    withRawBlockingStub { stub =>
+      val iter = stub.executePlan(buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY)))
+      iter.next() // open the iterator
+      // interrupt all RPCs on server
+      SparkConnectService.executionManager.interruptAllRPCs()
+      assertEventuallyNoActiveRpcs()
+      val e = intercept[StatusRuntimeException] {
+        while (iter.hasNext) iter.next()
+      }
+      assert(e.getMessage.contains("INVALID_CURSOR.DISCONNECTED"))
+    }
+  }
+
+  test("raw new RPC interrupts previous RPC with INVALID_CURSOR.DISCONNECTED error") {
+    // Raw stub does not have retries, auto reattach etc.
+    withRawBlockingStub { stub =>
+      val operationId = UUID.randomUUID().toString
+      val iter = stub.executePlan(
+        buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId))
+      iter.next() // open the iterator
+
+      // send reattach
+      val iter2 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
+      iter2.next() // open the iterator
+
+      // should result in INVALID_CURSOR.DISCONNECTED error on the original iterator
+      val e = intercept[StatusRuntimeException] {
+        while (iter.hasNext) iter.next()
+      }
+      assert(e.getMessage.contains("INVALID_CURSOR.DISCONNECTED"))
+
+      // send another reattach
+      val iter3 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
+      assert(iter3.hasNext)
+      iter3.next() // open the iterator
+
+      // should result in INVALID_CURSOR.DISCONNECTED error on the previous reattach iterator
+      val e2 = intercept[StatusRuntimeException] {
+        while (iter2.hasNext) iter2.next()
+      }
+      assert(e2.getMessage.contains("INVALID_CURSOR.DISCONNECTED"))
+    }
+  }
+
+  test("client INVALID_CURSOR.DISCONNECTED error is retried when rpc sender gets interrupted") {
+    withClient { client =>
+      val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
+      val reattachableIter = getReattachableIterator(iter)
+      val initialInnerIter = reattachableIter.innerIterator
+      val operationId = getReattachableIterator(iter).operationId
+
+      // open the iterator
+      iter.next()
+
+      // interrupt all RPCs on server
+      SparkConnectService.executionManager.interruptAllRPCs()
+      assertEventuallyNoActiveRpcs()
+
+      // Nevertheless, the original iterator will handle the INVALID_CURSOR.DISCONNECTED error
+      iter.next()
+      // iterator changed because it had to reconnect
+      assert(reattachableIter.innerIterator ne initialInnerIter)
+    }
+  }
+
+  test("client INVALID_CURSOR.DISCONNECTED error is retried when other RPC preempts this one") {
+    withClient { client =>
+      val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
+      val reattachableIter = getReattachableIterator(iter)
+      val initialInnerIter = reattachableIter.innerIterator
+      val operationId = getReattachableIterator(iter).operationId
+
+      // open the iterator
+      val response = iter.next()
+
+      // Send another Reattach request, it should preempt this request with an
+      // INVALID_CURSOR.DISCONNECTED error.
+      withRawBlockingStub { stub =>
+        val reattachIter = stub.reattachExecute(
+          buildReattachExecuteRequest(operationId, Some(response.getResponseId)))
+        assert(reattachIter.hasNext)
+        reattachIter.next()
+
+        // Nevertheless, the original iterator will handle the INVALID_CURSOR.DISCONNECTED error
+        iter.next()
+        // iterator changed because it had to reconnect
+        assert(reattachableIter.innerIterator ne initialInnerIter)
+      }
+    }
+  }
+
+  test("abandoned query gets INVALID_HANDLE.OPERATION_ABANDONED error") {
+    withClient { client =>
+      val plan = buildPlan("select * from range(100000)")
+      val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
+      val operationId = getReattachableIterator(iter).operationId
+      // open the iterator
+      iter.next()
+      // disconnect and remove on server
+      SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1)
+      assertEventuallyNoActiveRpcs()
+      SparkConnectService.executionManager.periodicMaintenance(0)
+      assertNoActiveExecutions()
+      // check that it throws abandoned error
+      val e = intercept[SparkException] {
+        while (iter.hasNext) iter.next()
+      }
+      assert(e.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
+      // check that afterwards, new operation can't be created with the same operationId.
+      withCustomBlockingStub() { stub =>
+        val executePlanReq = buildExecutePlanRequest(plan, operationId = operationId)
+
+        val iterNonReattachable = stub.executePlan(executePlanReq)
+        val eNonReattachable = intercept[SparkException] {
+          iterNonReattachable.hasNext
+        }
+        assert(eNonReattachable.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
+
+        val iterReattachable = stub.executePlanReattachable(executePlanReq)
+        val eReattachable = intercept[SparkException] {
+          iterReattachable.hasNext
+        }
+        assert(eReattachable.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
+      }
+    }
+  }
+
+  test("client releases responses directly after consuming them") {
+    withClient { client =>
+      val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
+      val reattachableIter = getReattachableIterator(iter)
+      val initialInnerIter = reattachableIter.innerIterator
+      val operationId = getReattachableIterator(iter).operationId
+
+      assert(iter.hasNext) // open iterator
+      val execution = getExecutionHolder
+      assert(execution.responseObserver.releasedUntilIndex == 0)
+
+      // get two responses, check on the server that ReleaseExecute releases them afterwards
+      val response1 = iter.next()
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(execution.responseObserver.releasedUntilIndex == 1)
+      }
+
+      val response2 = iter.next()
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(execution.responseObserver.releasedUntilIndex == 2)
+      }
+
+      withRawBlockingStub { stub =>
+        // Reattach after response1 should fail with INVALID_CURSOR.POSITION_NOT_AVAILABLE
+        val reattach1 = stub.reattachExecute(
+          buildReattachExecuteRequest(operationId, Some(response1.getResponseId)))
+        val e = intercept[StatusRuntimeException] {
+          reattach1.hasNext()
+        }
+        assert(e.getMessage.contains("INVALID_CURSOR.POSITION_NOT_AVAILABLE"))
+
+        // Reattach after response2 should work
+        val reattach2 = stub.reattachExecute(
+          buildReattachExecuteRequest(operationId, Some(response2.getResponseId)))
+        val response3 = reattach2.next()
+        val response4 = reattach2.next()
+        val response5 = reattach2.next()
+
+        // The original client iterator will handle the INVALID_CURSOR.DISCONNECTED error,
+        // and reconnect back. Since the raw iterator was not releasing responses, client iterator
+        // should be able to continue where it left off (server shouldn't have released yet)
+        assert(execution.responseObserver.releasedUntilIndex == 2)
+        assert(iter.hasNext)
+
+        val r3 = iter.next()
+        assert(r3.getResponseId == response3.getResponseId)
+        val r4 = iter.next()
+        assert(r4.getResponseId == response4.getResponseId)
+        val r5 = iter.next()
+        assert(r5.getResponseId == response5.getResponseId)
+        // inner iterator changed because it had to reconnect
+        assert(reattachableIter.innerIterator ne initialInnerIter)
+      }
+    }
+  }
+
+  test("server releases responses automatically when client moves ahead") {
+    withRawBlockingStub { stub =>
+      val operationId = UUID.randomUUID().toString
+      val iter = stub.executePlan(
+        buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId))
+      var lastSeenResponse: String = null
+
+      iter.hasNext // open iterator
+      val execution = getExecutionHolder
+
+      // after consuming enough from the iterator, server should automatically start releasing
+      var lastSeenIndex = 0
+      while (iter.hasNext && execution.responseObserver.releasedUntilIndex == 0) {
+        val r = iter.next()
+        lastSeenResponse = r.getResponseId()
+        lastSeenIndex += 1
+      }
+      assert(iter.hasNext)
+      assert(execution.responseObserver.releasedUntilIndex > 0)
+
+      // Reattach from the beginning is not available.
+      val reattach = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
+      val e = intercept[StatusRuntimeException] {
+        reattach.hasNext()
+      }
+      assert(e.getMessage.contains("INVALID_CURSOR.POSITION_NOT_AVAILABLE"))
+
+      // Original iterator got disconnected by the reattach and gets INVALID_CURSOR.DISCONNECTED
+      val e2 = intercept[StatusRuntimeException] {
+        while (iter.hasNext) iter.next()
+      }
+      assert(e2.getMessage.contains("INVALID_CURSOR.DISCONNECTED"))
+
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        // Even though we didn't consume more from the iterator, the server thinks that
+        // it sent more, because GRPC stream onNext() can push into internal GRPC buffer without
+        // client picking it up.
+        assert(execution.responseObserver.highestConsumedIndex > lastSeenIndex)
+      }
+      // but CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE is big enough that the last
+      // response we've seen is still in range
+      assert(execution.responseObserver.releasedUntilIndex < lastSeenIndex)
+
+      // and a new reattach can continue from the last response we have seen.
+      val reattach2 =
+        stub.reattachExecute(buildReattachExecuteRequest(operationId, Some(lastSeenResponse)))
+      assert(reattach2.hasNext)
+      while (reattach2.hasNext) reattach2.next()
+    }
+  }
+
+  // A few integration tests with large results.
+  // They should run significantly faster than the LARGE_QUERY_TIMEOUT
+  // - big query (4 seconds, 871 milliseconds)
+  // - big query and slow client (7 seconds, 288 milliseconds)
+  // - big query with frequent reattach (1 second, 527 milliseconds)
+  // - big query with frequent reattach and slow client (7 seconds, 365 milliseconds)
+  // - long sleeping query (10 seconds, 805 milliseconds)
+
+  // intentionally smaller than CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION,
+  // so that the reattach deadline doesn't "unstick" the test if something got stuck.
+  val LARGE_QUERY_TIMEOUT = 100.seconds
+
+  val LARGE_RESULTS_QUERY = s"select id, " +
+    (1 to 20).map(i => s"cast(id as string) c$i").mkString(", ") +
+    s" from range(1000000)"
+
+  test("big query") {
+    // regular query with large results
+    runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT)
+    // Check that execution is released on the server.
+    assertEventuallyNoActiveExecutions()
+  }
+
+  test("big query and slow client") {
+    // regular query with large results, but client is slow so sender will need to control flow
+    runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT, iterSleep = 50)
+    // Check that execution is released on the server.
+    assertEventuallyNoActiveExecutions()
+  }
+
+  test("big query with frequent reattach") {
+    // will reattach every 100kB
+    withSparkEnvConfs((Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE.key, "100k")) {
+      runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT)
+      // Check that execution is released on the server.
+      assertEventuallyNoActiveExecutions()
+    }
+  }
+
+  test("big query with frequent reattach and slow client") {
+    // will reattach every 100kB, and in addition the client is slow,
+    // so sender will need to control flow
+    withSparkEnvConfs((Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE.key, "100k")) {
+      runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT, iterSleep = 50)
+      // Check that execution is released on the server.
+      assertEventuallyNoActiveExecutions()
+    }
+  }
+
+  test("long sleeping query") {
+    // query will be sleeping and not returning results, while having multiple reattach
+    withSparkEnvConfs(
+      (Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION.key, "1s")) {
+      runQuery("select sleep(10000) as s", 30.seconds)
+      // Check that execution is released on the server.
+      assertEventuallyNoActiveExecutions()
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index f5819b95087..1163088c82a 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -302,6 +302,30 @@ abstract class SparkFunSuite
     }
   }
 
+  /**
+   * Sets all configurations specified in `pairs` in SparkEnv SparkConf, calls `f`, and then
+   * restores all configurations.
+   */
+  protected def withSparkEnvConfs(pairs: (String, String)*)(f: => Unit): Unit = {
+    val conf = SparkEnv.get.conf
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.getOption(key).isDefined) {
+        Some(conf.get(key))
+      } else {
+        None
+      }
+    }
+    pairs.foreach { kv => conf.set(kv._1, kv._2) }
+    try f
+    finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.set(key, value)
+        case (key, None) => conf.remove(key)
+      }
+    }
+  }
+
   /**
    * Checks an exception with an error class against expected results.
    * @param exception The exception to check
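
As a usage note, here is a minimal sketch of a suite built on this infrastructure (the suite name, test name, and query string are hypothetical; the helpers `withClient`, `buildPlan`, `setAllRPCsDeadline`, and `assertEventuallyNoActiveRpcs` are all introduced by this commit). It forces every active server-side RPC past its deadline and checks that the reattachable client transparently reattaches and drains the stream:

```scala
package org.apache.spark.sql.connect.execution

import org.apache.spark.sql.connect.SparkConnectServerTest
import org.apache.spark.sql.connect.service.SparkConnectService

// Hypothetical suite illustrating the new test infrastructure.
class ReattachRecoverySketchSuite extends SparkConnectServerTest {

  test("client drains results even if all server RPCs expire mid-stream") {
    withClient { client =>
      val iter = client.execute(buildPlan("select * from range(1000000)"))
      iter.next() // open the stream
      // Expire every active RPC on the server side.
      SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1)
      assertEventuallyNoActiveRpcs()
      // The reattachable iterator reattaches transparently and finishes the stream.
      while (iter.hasNext) iter.next()
    }
  }
}
```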