This is an automated email from the ASF dual-hosted git repository.

yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new f9c8c7a3312  [SPARK-45871][CONNECT] Optimize collection conversions related to `.toBuffer` in the `connect` modules
f9c8c7a3312 is described below

commit f9c8c7a3312e533afb95e527ca0c451148bad6a4
Author: yangjie01 <yangji...@baidu.com>
AuthorDate: Sun Nov 12 13:45:09 2023 +0800

    [SPARK-45871][CONNECT] Optimize collection conversions related to `.toBuffer` in the `connect` modules

    ### What changes were proposed in this pull request?
    This PR includes the following optimizations related to `.toBuffer` in the `connect` modules:

    1. For the two `SparkSession` functions `sql(String, java.util.Map[String, Any]): DataFrame` and `sql(String, Array[_]): DataFrame`, `.find` is now called directly on the `CloseableIterator` to locate the target response, and `.foreach` then consumes the remaining elements. This replaces the previous approach of materializing the iterator with `toBuffer.toSeq` before searching it with `.find`, and so avoids creating an unnecessary intermediate collection (a sketch of this pattern follows this message).
    2. For the function `execute(proto.Relation.Builder => Unit): Unit` in `SparkSession`, no elements are returned, so `.foreach` is used instead of `.toBuffer`, again avoiding an unnecessary collection.
    3. For the function `execute(command: proto.Command): Seq[ExecutePlanResponse]` in `SparkSession`: in Scala 2.12, `s.c.TraversableOnce#toSeq` returns an `immutable.Stream`, a tail-lazy structure that may not consume all elements, so `s.c.TraversableOnce#toBuffer` was necessary for materialization. In Scala 2.13, however, `s.c.IterableOnceOps#toSeq` constructs a strict `immutable.Seq`, which guarantees that all elements are consumed (a small demo follows this message). Therefore [...]
    4. The optimizations for the two functions `listAbandonedExecutions: Seq[ExecuteInfo]` and `listExecuteHolders` in `SparkConnectExecutionManager` are consistent with item 3 above. Additionally, to keep the test helper from creating copies, a `private[connect]`-scoped helper function was added to the companion object of `ExecutePlanResponseReattachableIterator`.

    ### Why are the changes needed?
    Avoids unnecessary collection copies.

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    - Passes GitHub Actions
    - Added new test cases demonstrating that both `.toSeq` and `.foreach` consume all elements of the `ExecutePlanResponseReattachableIterator` in Scala 2.13.

    ### Was this patch authored or co-authored using generative AI tooling?
    No

    Closes #43745 from LuciferYang/SPARK-45871.
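Items 1 and 2 above boil down to one consumption pattern: search the response iterator in place, then always drain whatever remains so the underlying stream gets fully consumed and closed. Below is a minimal, self-contained sketch of that pattern (not code from this commit; the `Response` type and sample data are hypothetical stand-ins for `proto.ExecutePlanResponse` and the iterator returned by `client.execute`):

```scala
// Sketch of the find-then-drain pattern from items 1 and 2 (hypothetical types).
object IteratorDrainSketch {
  final case class Response(hasResult: Boolean, payload: String)

  def main(args: Array[String]): Unit = {
    // Stand-in for the CloseableIterator returned by client.execute(...).
    val responses: Iterator[Response] =
      Iterator(Response(false, "a"), Response(true, "b"), Response(false, "c"))

    try {
      // .find consumes elements only up to (and including) the first match,
      // so no intermediate collection is built.
      val result = responses
        .find(_.hasResult)
        .getOrElse(throw new RuntimeException("result must be present"))
      println(s"found: ${result.payload}")
    } finally {
      // Drain the remainder so the iterator is fully consumed even on error.
      responses.foreach(_ => ())
    }
  }
}
```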
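Item 3 hinges on the differing semantics of `.toSeq` across Scala versions. Here is a minimal sketch, assuming only the standard library, that makes the difference observable by counting side effects (run it under 2.12 and 2.13 to compare):

```scala
// Demonstrates how much of an iterator `.toSeq` consumes per Scala version.
object ToSeqLazinessSketch {
  def main(args: Array[String]): Unit = {
    var consumed = 0
    // Each next() evaluates the function, so `consumed` tracks consumption.
    val iter = Iterator.tabulate(5) { i => consumed += 1; i }

    val seq = iter.toSeq
    // Scala 2.12: toSeq builds an immutable.Stream and forces only the head,
    // so `consumed` is 1 here and the iterator is left partially unread.
    // Scala 2.13: toSeq builds a strict immutable.Seq (a List),
    // so `consumed` is 5 and the iterator is fully drained.
    println(s"consumed after toSeq: $consumed")

    // Forcing the whole sequence drains the remainder on 2.12 as well.
    println(s"sum = ${seq.sum}, consumed after forcing: $consumed")
  }
}
```

This is why `.toSeq` is a safe replacement for `.toBuffer` only under Scala 2.13 (the version the master branch now builds with), and why the new test cases in the diff below assert `resultComplete` after calling `.toSeq`.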
Authored-by: yangjie01 <yangji...@baidu.com>
Signed-off-by: yangjie01 <yangji...@baidu.com>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  | 48 ++++++++++++----------
 .../connect/client/SparkConnectClientSuite.scala   | 48 ++++++++++++++++++++++
 .../ExecutePlanResponseReattachableIterator.scala  | 10 +++++
 .../service/SparkConnectExecutionManager.scala     |  6 +--
 .../spark/sql/connect/SparkConnectServerTest.scala | 12 +-----
 5 files changed, 90 insertions(+), 34 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 34756f9a440..ca692d2d4f8 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -250,15 +250,18 @@ class SparkSession private[sql] (
         .setSql(sqlText)
         .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)))
     val plan = proto.Plan.newBuilder().setCommand(cmd)
-    // .toBuffer forces that the iterator is consumed and closed
-    val responseSeq = client.execute(plan.build()).toBuffer.toSeq
+    val responseIter = client.execute(plan.build())
 
-    val response = responseSeq
-      .find(_.hasSqlCommandResult)
-      .getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
-
-    // Update the builder with the values from the result.
-    builder.mergeFrom(response.getSqlCommandResult.getRelation)
+    try {
+      val response = responseIter
+        .find(_.hasSqlCommandResult)
+        .getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
+      // Update the builder with the values from the result.
+      builder.mergeFrom(response.getSqlCommandResult.getRelation)
+    } finally {
+      // consume the rest of the iterator
+      responseIter.foreach(_ => ())
+    }
   }
 
   /**
@@ -309,15 +312,18 @@ class SparkSession private[sql] (
         .setSql(sqlText)
         .putAllNamedArguments(args.asScala.view.mapValues(lit(_).expr).toMap.asJava)))
     val plan = proto.Plan.newBuilder().setCommand(cmd)
-    // .toBuffer forces that the iterator is consumed and closed
-    val responseSeq = client.execute(plan.build()).toBuffer.toSeq
-
-    val response = responseSeq
-      .find(_.hasSqlCommandResult)
-      .getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
-
-    // Update the builder with the values from the result.
-    builder.mergeFrom(response.getSqlCommandResult.getRelation)
+    val responseIter = client.execute(plan.build())
+
+    try {
+      val response = responseIter
+        .find(_.hasSqlCommandResult)
+        .getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
+      // Update the builder with the values from the result.
+      builder.mergeFrom(response.getSqlCommandResult.getRelation)
+    } finally {
+      // consume the rest of the iterator
+      responseIter.foreach(_ => ())
+    }
   }
 
   /**
@@ -543,14 +549,14 @@ class SparkSession private[sql] (
     f(builder)
     builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
     val plan = proto.Plan.newBuilder().setRoot(builder).build()
-    // .toBuffer forces that the iterator is consumed and closed
-    client.execute(plan).toBuffer
+    // .foreach forces that the iterator is consumed and closed
+    client.execute(plan).foreach(_ => ())
   }
 
   private[sql] def execute(command: proto.Command): Seq[ExecutePlanResponse] = {
     val plan = proto.Plan.newBuilder().setCommand(command).build()
-    // .toBuffer forces that the iterator is consumed and closed
-    client.execute(plan).toBuffer.toSeq
+    // .toSeq forces that the iterator is consumed and closed
+    client.execute(plan).toSeq
   }
 
   private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index d0c85da5f21..b93713383b2 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -387,6 +387,54 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     }
     assert(dummyFn.counter == 2)
   }
+
+  test("SPARK-45871: Client execute iterator.toSeq consumes the reattachable iterator") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .enableReattachableExecute()
+      .build()
+    val session = SparkSession.builder().client(client).create()
+    val cmd = session.newCommand(b =>
+      b.setSqlCommand(
+        proto.SqlCommand
+          .newBuilder()
+          .setSql("select * from range(10000000)")))
+    val plan = proto.Plan.newBuilder().setCommand(cmd)
+    val iter = client.execute(plan.build())
+    val reattachableIter =
+      ExecutePlanResponseReattachableIterator.fromIterator(iter)
+    iter.toSeq
+    // In several places in SparkSession, we depend on `.toSeq` to consume and close the iterator.
+    // If this assertion fails, we need to double check the correctness of that.
+    // In scala 2.12 `s.c.TraversableOnce#toSeq` builds an `immutable.Stream`,
+    // which is a tail lazy structure and this would fail.
+    // In scala 2.13 `s.c.IterableOnceOps#toSeq` builds an `immutable.Seq` which is not
+    // lazy and will consume and close the iterator.
+    assert(reattachableIter.resultComplete)
+  }
+
+  test("SPARK-45871: Client execute iterator.foreach consumes the reattachable iterator") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .enableReattachableExecute()
+      .build()
+    val session = SparkSession.builder().client(client).create()
+    val cmd = session.newCommand(b =>
+      b.setSqlCommand(
+        proto.SqlCommand
+          .newBuilder()
+          .setSql("select * from range(10000000)")))
+    val plan = proto.Plan.newBuilder().setCommand(cmd)
+    val iter = client.execute(plan.build())
+    val reattachableIter =
+      ExecutePlanResponseReattachableIterator.fromIterator(iter)
+    iter.foreach(_ => ())
+    assert(reattachableIter.resultComplete)
+  }
 }
 
 class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
index 2b61463c343..cfa492ef063 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
@@ -326,3 +326,13 @@ class ExecutePlanResponseReattachableIterator(
 
   private def retry[T](fn: => T): T = GrpcRetryHandler.retry(retryPolicy)(fn)
 }
+
+private[connect] object ExecutePlanResponseReattachableIterator {
+  @scala.annotation.tailrec
+  private[connect] def fromIterator(
+      iter: Iterator[proto.ExecutePlanResponse]): ExecutePlanResponseReattachableIterator =
+    iter match {
+      case e: ExecutePlanResponseReattachableIterator => e
+      case w: WrappedCloseableIterator[proto.ExecutePlanResponse] => fromIterator(w.innerIterator)
+    }
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index c004358e1cf..36c6f73329b 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -153,7 +153,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
    * cache, and the tombstones will be eventually removed.
    */
   def listAbandonedExecutions: Seq[ExecuteInfo] = {
-    abandonedTombstones.asMap.asScala.values.toBuffer.toSeq
+    abandonedTombstones.asMap.asScala.values.toSeq
   }
 
   private[connect] def shutdown(): Unit = executionsLock.synchronized {
@@ -236,7 +236,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
     executions.values.foreach(_.interruptGrpcResponseSenders())
   }
 
-  private[connect] def listExecuteHolders = executionsLock.synchronized {
-    executions.values.toBuffer.toSeq
+  private[connect] def listExecuteHolders: Seq[ExecuteHolder] = executionsLock.synchronized {
+    executions.values.toSeq
   }
 }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index c4a5539ce0b..1c0d9a68ab6 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -27,7 +27,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, GrpcRetryHandler, SparkConnectClient, SparkConnectStubState, WrappedCloseableIterator}
+import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, GrpcRetryHandler, SparkConnectClient, SparkConnectStubState}
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.connect.config.Connect
@@ -147,15 +147,7 @@ trait SparkConnectServerTest extends SharedSparkSession {
 
   protected def getReattachableIterator(
       stubIterator: CloseableIterator[proto.ExecutePlanResponse]) = {
-    // This depends on the wrapping in CustomSparkConnectBlockingStub.executePlanReattachable:
-    // GrpcExceptionConverter.convertIterator
-    stubIterator
-      .asInstanceOf[WrappedCloseableIterator[proto.ExecutePlanResponse]]
-      .innerIterator
-      .asInstanceOf[WrappedCloseableIterator[proto.ExecutePlanResponse]]
-      // ExecutePlanResponseReattachableIterator
-      .innerIterator
-      .asInstanceOf[ExecutePlanResponseReattachableIterator]
+    ExecutePlanResponseReattachableIterator.fromIterator(stubIterator)
   }
 
   protected def assertNoActiveRpcs(): Unit = {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org