This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f6999df0c7f0  [SPARK-47081][CONNECT] Support Query Execution Progress
f6999df0c7f0 is described below

commit f6999df0c7f0bb18778b29ebdbe9f7d40899808a
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Thu Apr 4 12:59:56 2024 +0800

    [SPARK-47081][CONNECT] Support Query Execution Progress

    ### What changes were proposed in this pull request?

    This patch adds a new mechanism to push query execution progress for batch queries. We add a
    new response message type and periodically push query progress to the client. The client can
    consume this data to, for example, display a progress bar. This patch adds support for
    displaying a progress bar in the PySpark shell when started with Spark Connect.

    The proto message is defined as follows:

    ```
    // This message is used to communicate progress about the query progress during the execution.
    message ExecutionProgress {
      // Captures the progress of each individual stage.
      repeated StageInfo stages = 1;

      // Captures the currently in progress tasks.
      int64 num_inflight_tasks = 2;

      message StageInfo {
        int64 stage_id = 1;
        int64 num_tasks = 2;
        int64 num_completed_tasks = 3;
        int64 input_bytes_read = 4;
        bool done = 5;
      }
    }
    ```

    Clients can consume these messages or simply ignore them. In addition, this patch adds the
    ability to register a callback for progress tracking on the SparkSession:

    ```
    handler = lambda **kwargs: print(kwargs)
    spark.registerProgressHandler(handler)
    spark.range(100).collect()
    spark.removeProgressHandler(handler)
    ```

    #### Example 1
    ![progress_medium_query_multi_stage mp4](https://github.com/apache/spark/assets/3421/5eff1ec4-def2-4d39-8a75-13a6af784c99)

    #### Example 2
    ![progress_bar mp4](https://github.com/apache/spark/assets/3421/20638511-2da4-4bd6-83f2-da3b9f500bde)

    ### Why are the changes needed?

    Usability and experience.

    ### Does this PR introduce _any_ user-facing change?

    When the user opens the PySpark shell in Spark Connect mode, it will show the progress bar by
    default.

    ### How was this patch tested?

    Added new tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #45150 from grundprinzip/SPARK-47081.
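For illustration, a structured handler conforming to the `ProgressHandler` interface that this patch adds in `python/pyspark/sql/connect/shell/progress.py` could look like the sketch below. The handler class, its printed format, and the pre-existing Spark Connect session named `spark` are assumptions for this example, not part of the patch:

```python
from typing import Iterable, Optional

from pyspark.sql.connect.shell.progress import ProgressHandler, StageInfo


class PrintingProgressHandler(ProgressHandler):
    """Hypothetical handler: prints aggregated task progress on every server update."""

    def __call__(
        self,
        stages: Optional[Iterable[StageInfo]],
        inflight_tasks: int,
        done: bool,
    ) -> None:
        stages = list(stages or [])
        # Sum task counters across all stages reported so far.
        total = sum(s.num_tasks for s in stages)
        completed = sum(s.num_completed_tasks for s in stages)
        print(f"tasks: {completed}/{total}, in flight: {inflight_tasks}, done: {done}")


handler = PrintingProgressHandler()
spark.registerProgressHandler(handler)  # assumes an existing Spark Connect session
spark.range(100).collect()
spark.removeProgressHandler(handler)
```

Because `Progress._notify` invokes handlers with the keyword arguments `stages`, `inflight_tasks`, and `done`, a plain callable accepting those keywords (like the lambda in the example above) works as well.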
Authored-by: Martin Grund <martin.gr...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../apache/spark/sql/SparkSessionE2ESuite.scala | 10 + .../src/main/protobuf/spark/connect/base.proto | 22 ++- .../spark/sql/connect/client/SparkResult.scala | 39 ++++ .../apache/spark/sql/connect/config/Connect.scala | 8 + .../ConnectProgressExecutionListener.scala | 191 +++++++++++++++++++ .../execution/ExecuteGrpcResponseSender.scala | 51 ++++- .../execution/ExecuteResponseObserver.scala | 11 +- .../connect/execution/ExecuteThreadRunner.scala | 5 +- .../sql/connect/service/SparkConnectService.scala | 5 + .../ConnectProgressExecutionListenerSuite.scala | 156 ++++++++++++++++ .../org/apache/spark/deploy/SparkSubmit.scala | 7 +- dev/sparktestsupport/modules.py | 1 + .../source/reference/pyspark.sql/spark_session.rst | 13 +- python/pyspark/shell.py | 21 ++- python/pyspark/sql/connect/client/core.py | 55 +++++- python/pyspark/sql/connect/proto/base_pb2.py | 208 +++++++++++---------- python/pyspark/sql/connect/proto/base_pb2.pyi | 82 +++++++- python/pyspark/sql/connect/session.py | 16 ++ python/pyspark/sql/connect/shell/__init__.py | 26 +++ python/pyspark/sql/connect/shell/progress.py | 187 ++++++++++++++++++ python/pyspark/sql/session.py | 61 ++++++ python/pyspark/sql/tests/connect/shell/__init__.py | 16 ++ .../sql/tests/connect/shell/test_progress.py | 111 +++++++++++ .../sql/tests/connect/test_connect_session.py | 21 +++ 24 files changed, 1204 insertions(+), 119 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala index e4cbcf620d15..b967245d90c2 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala @@ -229,6 +229,16 @@ class SparkSessionE2ESuite extends RemoteSparkSession { assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") } + test("progress is available for the spark result") { + val result = spark + .range(10000) + .repartition(1000) + .collectResult() + assert(result.length == 10000) + assert(result.progress.stages.map(_.numTasks).sum > 100) + assert(result.progress.stages.map(_.completedTasks).sum > 100) + } + test("interrupt operation") { val session = spark import session.implicits._ diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 9a9121d84f76..49a33d3419b6 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -333,7 +333,7 @@ message ExecutePlanRequest { // The response of a query, can be one or more for each request. Responses belonging to the // same input query, carry the same `session_id`. -// Next ID: 16 +// Next ID: 17 message ExecutePlanResponse { string session_id = 1; // Server-side generated idempotency key that the client can use to assert that the server side @@ -378,6 +378,9 @@ message ExecutePlanResponse { // Response for command that creates ResourceProfile. CreateResourceProfileCommandResult create_resource_profile_command_result = 17; + // (Optional) Intermediate query progress reports. + ExecutionProgress execution_progress = 18; + // Support arbitrary result objects. 
google.protobuf.Any extension = 999; } @@ -438,6 +441,23 @@ message ExecutePlanResponse { // the execution is complete. If the server sends onComplete without sending a ResultComplete, // it means that there is more, and the client should use ReattachExecute RPC to continue. } + + // This message is used to communicate progress about the query progress during the execution. + message ExecutionProgress { + // Captures the progress of each individual stage. + repeated StageInfo stages = 1; + + // Captures the currently in progress tasks. + int64 num_inflight_tasks = 2; + + message StageInfo { + int64 stage_id = 1; + int64 num_tasks = 2; + int64 num_completed_tasks = 3; + int64 input_bytes_read = 4; + bool done = 5; + } + } } // The key-value pair for the config request and response. diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index 7a7c6a2d6c92..93d1075aea02 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -20,6 +20,7 @@ import java.lang.ref.Cleaner import java.util.Objects import scala.collection.mutable +import scala.jdk.CollectionConverters._ import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch} @@ -40,6 +41,38 @@ private[sql] class SparkResult[T]( timeZoneId: String) extends AutoCloseable { self => + case class StageInfo( + stageId: Long, + numTasks: Long, + completedTasks: Long = 0, + inputBytesRead: Long = 0, + completed: Boolean = false) + + object StageInfo { + def apply(stageInfo: proto.ExecutePlanResponse.ExecutionProgress.StageInfo): StageInfo = { + StageInfo( + stageInfo.getStageId, + stageInfo.getNumTasks, + stageInfo.getNumCompletedTasks, + stageInfo.getInputBytesRead, + stageInfo.getDone) + } + } + + object Progress { + def apply(progress: proto.ExecutePlanResponse.ExecutionProgress): Progress = { + Progress( + progress.getStagesList.asScala.map(StageInfo(_)).toSeq, + progress.getNumInflightTasks) + } + } + + /** + * Progress of the query execution. This information can be accessed from the iterator. + */ + case class Progress(stages: Seq[StageInfo], inflight: Long) + + var progress: Progress = new Progress(Seq.empty, 0) private[this] var opId: String = _ private[this] var numRecords: Int = 0 private[this] var structType: StructType = _ @@ -97,6 +130,12 @@ private[sql] class SparkResult[T]( } stop |= stopOnOperationId + // Update the execution status. This information can now be accessed directly from + // the iterator. + if (response.hasExecutionProgress) { + progress = Progress(response.getExecutionProgress) + } + if (response.hasSchema) { // The original schema should arrive before ArrowBatches. 
        structType =
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 39bf1a630af6..6ba100af1bb9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -265,4 +265,12 @@ object Connect {
       .version("4.0.0")
       .bytesConf(ByteUnit.BYTE)
       .createWithDefault(1024)
+
+  val CONNECT_PROGRESS_REPORT_INTERVAL =
+    buildConf("spark.connect.progress.reportInterval")
+      .doc("The interval at which the progress of a query is reported to the client." +
+        " If the value is set to a negative value, the progress reports will be disabled.")
+      .version("4.0.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("2s")
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
new file mode 100644
index 000000000000..954956363505
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.execution
+
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
+
+import org.apache.spark.connect.proto.ExecutePlanResponse
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
+
+/**
+ * A listener that tracks the execution of jobs and stages for a given set of tags. This is used
+ * to track the progress of a job that is being executed through the connect API.
+ *
+ * The listener is instantiated once for the SparkConnectService and then used to track all the
+ * current query executions.
+ */
+private[connect] class ConnectProgressExecutionListener extends SparkListener with Logging {
+
+  /**
+   * A tracker for a given tag. This is used to track the progress of an operation that is being
+   * executed through the connect API.
+   */
+  class ExecutionTracker(val tag: String) {
+
+    class StageInfo(
+        val stageId: Int,
+        var numTasks: Int,
+        var completedTasks: Int = 0,
+        var inputBytesRead: Long = 0,
+        var completed: Boolean = false) {
+
+      val lock = new Object
+      def update(i: StageInfo => Unit): Unit = {
+        lock.synchronized {
+          i(this)
+        }
+      }
+
+      def toProto(): ExecutePlanResponse.ExecutionProgress.StageInfo = {
+        ExecutePlanResponse.ExecutionProgress.StageInfo
+          .newBuilder()
+          .setStageId(stageId)
+          .setNumTasks(numTasks)
+          .setNumCompletedTasks(completedTasks)
+          .setInputBytesRead(inputBytesRead)
+          .setDone(completed)
+          .build()
+      }
+    }
+
+    // The set of jobs that are being tracked by this tracker. We always only add to this list
+    // but never remove. This is to avoid concurrency issues.
+    private[ConnectProgressExecutionListener] var jobs: Set[Int] = Set()
+    // The set of stages that are being tracked by this tracker. We always only add to this list
+    // but never remove. This is to avoid concurrency issues.
+    private[ConnectProgressExecutionListener] var stages: Map[Int, StageInfo] = Map.empty
+    // The tracker is marked as dirty if it has new progress to report.
+    private[ConnectProgressExecutionListener] val dirty = new AtomicBoolean(false)
+    // Tracks all currently running tasks for a particular tracker.
+    private[ConnectProgressExecutionListener] val inFlightTasks = new AtomicInteger(0)
+
+    /**
+     * Yield the current state of the tracker if it is dirty. A consumer of the tracker can
+     * provide a callback that will be called with the current state of the tracker if the
+     * tracker has new progress to report.
+     *
+     * If the tracker was marked as dirty, the state is reset after.
+     */
+    def yieldWhenDirty(thunk: (Seq[StageInfo], Long) => Unit): Unit = {
+      if (dirty.get()) {
+        thunk(stages.values.toSeq, inFlightTasks.get())
+        dirty.set(false)
+      }
+    }
+
+    /**
+     * Add a job to the tracker. This will add the job to the list of jobs that are being
+     * tracked.
+     */
+    def addJob(job: SparkListenerJobStart): Unit = synchronized {
+      jobs = jobs + job.jobId
+      job.stageInfos.foreach { stage =>
+        stages = stages + (stage.stageId -> new StageInfo(stage.stageId, stage.numTasks))
+      }
+      dirty.set(true)
+    }
+
+    def jobCount(): Int = {
+      jobs.size
+    }
+
+    def stageCount(): Int = {
+      stages.size
+    }
+  }
+
+  val trackedTags = collection.concurrent.TrieMap[String, ExecutionTracker]()
+
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+    val tags = jobStart.properties.getProperty("spark.job.tags")
+    if (tags != null) {
+      val thisJobTags = tags.split(",").map(_.trim).toSet
+      thisJobTags.foreach { tag =>
+        trackedTags.get(tag).foreach { tracker =>
+          tracker.addJob(jobStart)
+        }
+      }
+    }
+  }
+
+  override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+    // Check if the task belongs to a job that we are tracking.
+    trackedTags.foreach({ case (_, tracker) =>
+      if (tracker.stages.contains(taskStart.stageId)) {
+        tracker.inFlightTasks.incrementAndGet()
+        tracker.dirty.set(true)
+      }
+    })
+  }
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    // Check if the task belongs to a job that we are tracking.
+    trackedTags.foreach({ case (_, tracker) =>
+      if (tracker.stages.contains(taskEnd.stageId)) {
+        tracker.stages.get(taskEnd.stageId).foreach { stage =>
+          stage.update { i =>
+            i.completedTasks += 1
+            i.inputBytesRead += taskEnd.taskMetrics.inputMetrics.bytesRead
+          }
+        }
+        // This should never become negative, simply reset to zero if it does.
+        tracker.inFlightTasks.decrementAndGet()
+        if (tracker.inFlightTasks.get() < 0) {
+          tracker.inFlightTasks.set(0)
+        }
+        tracker.dirty.set(true)
+      }
+    })
+  }
+
+  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+    trackedTags.foreach({ case (_, tracker) =>
+      if (tracker.stages.contains(stageCompleted.stageInfo.stageId)) {
+        tracker.stages(stageCompleted.stageInfo.stageId).update { stage =>
+          stage.completed = true
+        }
+        tracker.dirty.set(true)
+      }
+    })
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+    trackedTags.foreach({ case (_, tracker) =>
+      if (tracker.jobs.contains(jobEnd.jobId)) {
+        tracker.dirty.set(true)
+      }
+    })
+  }
+
+  def tryGetTracker(tag: String): Option[ExecutionTracker] = {
+    trackedTags.get(tag)
+  }
+
+  def registerJobTag(tag: String): Unit = {
+    trackedTags += tag -> new ExecutionTracker(tag)
+  }
+
+  def removeJobTag(tag: String): Unit = {
+    trackedTags -= tag
+  }
+
+  def clearJobTags(): Unit = {
+    trackedTags.clear()
+  }
+
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
index c9ceef969e29..a9444862b3aa 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
@@ -17,13 +17,16 @@
 package org.apache.spark.sql.connect.execution
 
+import scala.jdk.CollectionConverters._
+
 import com.google.protobuf.Message
 import io.grpc.stub.{ServerCallStreamObserver, StreamObserver}
 
 import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.connect.proto.ExecutePlanResponse
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.common.ProtoUtils
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION, CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE}
+import org.apache.spark.sql.connect.config.Connect.{CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION, CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE, CONNECT_PROGRESS_REPORT_INTERVAL}
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
 
@@ -131,6 +134,38 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
     }
   }
 
+  /**
+   * This method is called repeatedly during the query execution to enqueue a new message to be
+   * sent to the client about the current query progress. The message is not sent directly to
+   * the client, but rather enqueued in the response observer.
+   */
+  private def enqueueProgressMessage(): Unit = {
+    if (executeHolder.sessionHolder.session.conf.get(CONNECT_PROGRESS_REPORT_INTERVAL) > 0) {
+      SparkConnectService.executionListener.foreach { listener =>
+        // It is possible that the tracker is no longer available and in this
+        // case we simply ignore it and do not send any progress message. This avoids
+        // having to synchronize on the listener.
+        listener.tryGetTracker(executeHolder.jobTag).foreach { tracker =>
+          // Only send a progress message if there is something new to report.
+          tracker.yieldWhenDirty { (stages, inflightTasks) =>
+            val response = ExecutePlanResponse
+              .newBuilder()
+              .setExecutionProgress(
+                ExecutePlanResponse.ExecutionProgress
+                  .newBuilder()
+                  .addAllStages(stages.map(_.toProto()).asJava)
+                  .setNumInflightTasks(inflightTasks))
+              .build()
+            // There is a special case when the response observer has already determined
+            // that the final message has been sent (and the stream will be closed), but we
+            // might still want to send the progress message. In this case we ignore the
+            // result of the `tryOnNext` call.
+            executeHolder.responseObserver.tryOnNext(response)
+          }
+        }
+      }
+    }
+  }
+
   /**
    * Attach to the executionObserver, consume responses from it, and send them to grpcObserver.
    *
@@ -173,6 +208,7 @@
     var sentResponsesSize: Long = 0
 
     while (!finished) {
+      enqueueProgressMessage()
       var response: Option[CachedStreamResponse[T]] = None
 
       // Conditions for exiting the inner loop (and helpers to compute them):
@@ -201,9 +237,18 @@
         // The state of interrupted, response and lastIndex are changed under executionObserver
         // monitor, and will notify upon state change.
         if (response.isEmpty) {
-          val timeout = Math.max(1, deadlineTimeMillis - System.currentTimeMillis())
+          // Wake up more frequently to send the progress updates.
+          val progressTimeout =
+            executeHolder.sessionHolder.session.conf.get(CONNECT_PROGRESS_REPORT_INTERVAL)
+          // If the progress feature is disabled, wait for the deadline.
+          val timeout = if (progressTimeout > 0) {
+            progressTimeout
+          } else {
+            Math.max(1, deadlineTimeMillis - System.currentTimeMillis())
+          }
           logTrace(s"Wait for response to become available with timeout=$timeout ms.")
           executionObserver.responseLock.wait(timeout)
+          enqueueProgressMessage()
           logTrace(s"Reacquired executionObserver lock after waiting.")
           sleepEnd = System.nanoTime()
         }
@@ -228,6 +273,7 @@
           s"waitingForResults=${consumeSleep}ns waitingForSend=${sendSleep}ns")
         throw new SparkSQLException(errorClass = "INVALID_CURSOR.DISCONNECTED", Map.empty)
       } else if (gotResponse) {
+        enqueueProgressMessage()
         // There is a response available to be sent.
val sent = sendResponse(response.get, deadlineTimeMillis) if (sent) { @@ -240,6 +286,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( assert(deadlineLimitReached || interrupted) } } else if (streamFinished) { + enqueueProgressMessage() // Stream is finished and all responses have been sent logInfo( s"Stream finished for opId=${executeHolder.operationId}, " + diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala index a7877503f461..92c23c6165d2 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala @@ -107,9 +107,9 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder: 0 } - def onNext(r: T): Unit = responseLock.synchronized { + def tryOnNext(r: T): Boolean = responseLock.synchronized { if (finalProducedIndex.nonEmpty) { - throw new IllegalStateException("Stream onNext can't be called after stream completed") + return false } lastProducedIndex += 1 val processedResponse = setCommonResponseFields(r) @@ -127,6 +127,13 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder: s"Execution opId=${executeHolder.operationId} produced response " + s"responseId=${responseId} idx=$lastProducedIndex") responseLock.notifyAll() + true + } + + def onNext(r: T): Unit = { + if (!tryOnNext(r)) { + throw new IllegalStateException("Stream onNext can't be called after stream completed") + } } def onError(t: Throwable): Unit = responseLock.synchronized { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 41146e4ef688..56776819dac9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -28,7 +28,7 @@ import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.common.ProtoUtils import org.apache.spark.sql.connect.planner.SparkConnectPlanner -import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag} +import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService} import org.apache.spark.sql.connect.utils.ErrorUtils import org.apache.spark.util.Utils @@ -123,6 +123,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends } } finally { executeHolder.sessionHolder.session.sparkContext.removeJobTag(executeHolder.jobTag) + SparkConnectService.executionListener.foreach(_.removeJobTag(executeHolder.jobTag)) executeHolder.sparkSessionTags.foreach { tag => executeHolder.sessionHolder.session.sparkContext.removeJobTag( ExecuteSessionTag( @@ -158,6 +159,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends // Set tag for query cancellation session.sparkContext.addJobTag(executeHolder.jobTag) + // Register the job for progress reports. + SparkConnectService.executionListener.foreach(_.registerJobTag(executeHolder.jobTag)) // Also set all user defined tags as Spark Job tags. 
executeHolder.sparkSessionTags.foreach { tag => session.sparkContext.addJobTag( diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 9324e8e6c5f1..476254bc6e39 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -38,6 +38,7 @@ import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService import org.apache.spark.internal.Logging import org.apache.spark.internal.config.UI.UI_ENABLED import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE} +import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab} import org.apache.spark.sql.connect.utils.ErrorUtils import org.apache.spark.status.ElementTrackingStore @@ -284,6 +285,7 @@ object SparkConnectService extends Logging { private[connect] var uiTab: Option[SparkConnectServerTab] = None private[connect] var listener: SparkConnectServerListener = _ + private[connect] var executionListener: Option[ConnectProgressExecutionListener] = None // For testing purpose, it's package level private. private[connect] def localPort: Int = { @@ -330,6 +332,9 @@ object SparkConnectService extends Logging { } else { None } + // Add the execution listener needed for query progress. + executionListener = Some(new ConnectProgressExecutionListener) + sc.addSparkListener(executionListener.get) } /** diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala new file mode 100644 index 000000000000..43e978a18f1f --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.connect.execution
+
+import java.util.Properties
+
+import org.mockito.Mockito.when
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.{SparkFunSuite, Success}
+import org.apache.spark.executor.{ExecutorMetrics, InputMetrics, TaskMetrics}
+import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart, StageInfo, TaskInfo}
+
+class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSugar {
+
+  def mockStage(stageId: Int, numTasks: Int): StageInfo = {
+    val result = mock[StageInfo]
+    when(result.stageId).thenReturn(stageId)
+    when(result.numTasks).thenReturn(numTasks)
+    result
+  }
+
+  val testTag = "testTag"
+  val testStage1 = mockStage(1, 1)
+  val testStage2 = mockStage(2, 1)
+
+  val testStage1Task1 = mock[TaskInfo]
+  val testStage1Task1ExecutorMetrics = mock[ExecutorMetrics]
+  val testStage1Task1Metrics = mock[TaskMetrics]
+
+  val inputMetrics = mock[InputMetrics]
+  when(inputMetrics.bytesRead).thenReturn(500)
+  when(testStage1Task1Metrics.inputMetrics).thenReturn(inputMetrics)
+
+  val testStage2Task1 = mock[TaskInfo]
+
+  val testProperties = new Properties()
+  testProperties.setProperty("spark.job.tags", s"otherTag,$testTag,anotherTag")
+
+  val testJobStart = SparkListenerJobStart(1, 1, Seq(testStage1, testStage2), testProperties)
+  val testTaskStart = SparkListenerTaskStart(1, 1, testStage1Task1)
+
+  test("onJobStart with no matching tags") {
+    val listener = new ConnectProgressExecutionListener
+    listener.onJobStart(testJobStart)
+    assert(listener.trackedTags.isEmpty)
+  }
+
+  test("onJobStart with a registered tag") {
+    val listener = new ConnectProgressExecutionListener
+    listener.registerJobTag(testTag)
+    assert(listener.trackedTags.size == 1)
+
+    // Trigger the event
+    listener.onJobStart(testJobStart)
+    val t = listener.trackedTags(testTag)
+
+    t.yieldWhenDirty((stages, inflight) => {
+      assert(stages.map(_.numTasks).sum == 2)
+      assert(stages.map(_.completedTasks).sum == 0)
+      assert(stages.size == 2)
+      assert(stages.map(_.inputBytesRead).sum == 0)
+      assert(inflight == 0)
+    })
+  }
+
+  test("taskDone") {
+    val listener = new ConnectProgressExecutionListener
+    listener.registerJobTag(testTag)
+    listener.onJobStart(testJobStart)
+
+    // Finish the tasks
+    val taskEnd = SparkListenerTaskEnd(
+      1,
+      1,
+      "taskType",
+      Success,
+      testStage1Task1,
+      testStage1Task1ExecutorMetrics,
+      testStage1Task1Metrics)
+
+    val t = listener.trackedTags(testTag)
+    var yielded = false
+    t.yieldWhenDirty { (stages, inflight) =>
+      assert(stages.map(_.numTasks).sum == 2)
+      assert(stages.map(_.completedTasks).sum == 0)
+      assert(stages.size == 2)
+      assert(
+        stages
+          .map(_.completed match {
+            case true => 1
+            case false => 0
+          })
+          .sum == 0)
+      yielded = true
+    }
+    assert(yielded, "Must be updated with results")
+
+    yielded = false
+    listener.onTaskEnd(taskEnd)
+    t.yieldWhenDirty { (stages, inflight) =>
+      assert(stages.map(_.numTasks).sum == 2)
+      assert(stages.map(_.completedTasks).sum == 1)
+      assert(stages.size == 2)
+      assert(stages.map(_.inputBytesRead).sum == 500)
+      assert(
+        stages
+          .map(_.completed match {
+            case true => 1
+            case false => 0
+          })
+          .sum == 0)
+      yielded = true
+    }
+    assert(yielded, "Must be updated with results")
+    yielded = false
+    t.yieldWhenDirty { (stages, inflight) =>
+      yielded = true
+    }
+    assert(!yielded, "Must not update if not dirty")
+
+    val stageEnd = SparkListenerStageCompleted(testStage1)
+    listener.onStageCompleted(stageEnd)
+
+    t.yieldWhenDirty { (stages, inflight) =>
+      assert(stages.map(_.numTasks).sum == 2)
+      assert(stages.map(_.completedTasks).sum == 1)
+      assert(stages.size == 2)
+      assert(stages.map(_.inputBytesRead).sum == 500)
+      assert(
+        stages
+          .map(_.completed match {
+            case true => 1
+            case false => 0
+          })
+          .sum == 1)
+      yielded = true
+    }
+    assert(yielded, "Must be updated with results")
+  }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 9ab394741a82..c0df74f8d0cc 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -744,8 +744,11 @@ private[spark] class SparkSubmit extends Logging {
       }
     }
 
-    // In case of shells, spark.ui.showConsoleProgress can be true by default or by user.
-    if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) {
+    // In case of shells, spark.ui.showConsoleProgress can be true by default or by user, except
+    // when Spark Connect is in local mode, because Spark Connect supports its own progress
+    // reporting.
+    if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS) &&
+        !sparkConf.contains("spark.local.connect")) {
       sparkConf.set(UI_SHOW_CONSOLE_PROGRESS, true)
     }
 
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 6b087436c687..d3ffa79ebe68 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1060,6 +1060,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg",
         "pyspark.sql.tests.connect.test_parity_pandas_udf_window",
         "pyspark.sql.tests.connect.test_resources",
+        "pyspark.sql.tests.connect.shell.test_progress",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst
index ea71249e292e..4e679da59c16 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -78,12 +78,15 @@ Spark Connect Only
     SparkSession.addArtifact
     SparkSession.addArtifacts
-    SparkSession.copyFromLocalToFs
+    SparkSession.addTag
+    SparkSession.clearProgressHandlers
+    SparkSession.clearTags
     SparkSession.client
+    SparkSession.copyFromLocalToFs
+    SparkSession.getTags
     SparkSession.interruptAll
-    SparkSession.interruptTag
     SparkSession.interruptOperation
-    SparkSession.addTag
+    SparkSession.interruptTag
+    SparkSession.registerProgressHandler
+    SparkSession.removeProgressHandler
     SparkSession.removeTag
-    SparkSession.getTags
-    SparkSession.clearTags
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index f705f0edd8fe..12ff86ecc9ff 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -45,11 +45,30 @@ if getattr(builtins, "__IPYTHON__", False):
     if parent_dir in sys.path:
         sys.path.remove(parent_dir)
-
 if is_remote():
     try:
         # Creates pyspark.sql.connect.SparkSession.
         spark = SparkSession.builder.getOrCreate()
+
+        from pyspark.sql.connect.shell import PROGRESS_BAR_ENABLED
+
+        # Check if the progress bar needs to be disabled.
+ if PROGRESS_BAR_ENABLED not in os.environ: + os.environ[PROGRESS_BAR_ENABLED] = "1" + else: + val = os.getenv(PROGRESS_BAR_ENABLED, "false") + if val.lower().strip() == "false": + os.environ[PROGRESS_BAR_ENABLED] = "0" + elif val.lower().strip() == "true": + os.environ[PROGRESS_BAR_ENABLED] = "1" + + val = os.environ[PROGRESS_BAR_ENABLED] + if val not in ("1", "0"): + raise ValueError( + f"Environment variable '{PROGRESS_BAR_ENABLED}' must " + f"be set to either 1 or 0, found: {val}" + ) + except Exception: import sys import traceback diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index b3807d80f6c9..17b5d99aba94 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -93,6 +93,7 @@ from pyspark.sql.types import DataType, StructType, TimestampType, _has_type from pyspark.util import PythonEvalType from pyspark.storagelevel import StorageLevel from pyspark.errors import PySparkValueError, PySparkAssertionError, PySparkNotImplementedError +from pyspark.sql.connect.shell.progress import Progress, ProgressHandler, from_proto if TYPE_CHECKING: from google.rpc.error_details_pb2 import ErrorInfo @@ -694,6 +695,37 @@ class SparkConnectClient(object): self._profiler_collector = ConnectProfilerCollector() + self._progress_handlers: List[ProgressHandler] = [] + + def register_progress_handler(self, handler: ProgressHandler) -> None: + """ + Register a progress handler to be called when a progress message is received. + + Parameters + ---------- + handler : ProgressHandler + The callable that will be called with the progress information. + + """ + if handler in self._progress_handlers: + return + self._progress_handlers.append(handler) + + def clear_progress_handlers(self) -> None: + self._progress_handlers.clear() + + def remove_progress_handler(self, handler: ProgressHandler) -> None: + """ + Remove a progress handler from the list of registered handlers. + + Parameters + ---------- + handler : ProgressHandler + The callable to remove from the list of progress handlers. + + """ + self._progress_handlers.remove(handler) + def _retrying(self) -> "Retrying": return Retrying(self._retry_policies) @@ -1213,7 +1245,10 @@ class SparkConnectClient(object): self._handle_error(error) def _execute_and_fetch_as_iterator( - self, req: pb2.ExecutePlanRequest, observations: Dict[str, Observation] + self, + req: pb2.ExecutePlanRequest, + observations: Dict[str, Observation], + progress: Optional["Progress"] = None, ) -> Iterator[ Union[ "pa.RecordBatch", @@ -1292,6 +1327,10 @@ class SparkConnectClient(object): yield {"get_resources_command_result": resources} if b.HasField("extension"): yield b.extension + if b.HasField("execution_progress"): + if progress: + p = from_proto(b.execution_progress) + progress.update_ticks(*p) if b.HasField("arrow_batch"): logger.debug( f"Received arrow batch rows={b.arrow_batch.row_count} " @@ -1338,6 +1377,16 @@ class SparkConnectClient(object): with attempt: for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): yield from handle_response(b) + except KeyboardInterrupt: + logger.debug(f"Interrupt request received for operation={req.operation_id}") + try: + self.interrupt_operation(req.operation_id) + except Exception as e: + # Swallow all errors if aborted. 
+ logger.debug(f"Caught an error during interrupt handling, silenced: {e}") + pass + if progress is not None: + progress.finish() except Exception as error: self._handle_error(error) @@ -1361,7 +1410,8 @@ class SparkConnectClient(object): schema: Optional[StructType] = None properties: Dict[str, Any] = {} - for response in self._execute_and_fetch_as_iterator(req, observations): + progress = Progress(handlers=self._progress_handlers) + for response in self._execute_and_fetch_as_iterator(req, observations, progress=progress): if isinstance(response, StructType): schema = response elif isinstance(response, pa.RecordBatch): @@ -1379,6 +1429,7 @@ class SparkConnectClient(object): "response": response, }, ) + progress.finish() if len(batches) > 0: if self_destruct: diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 2943057a99fc..b9f88aab3c26 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...] + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...] 
) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -120,105 +120,109 @@ if _descriptor._USE_C_DESCRIPTORS == False: _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 5196 _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5361 _EXECUTEPLANRESPONSE._serialized_start = 5440 - _EXECUTEPLANRESPONSE._serialized_end = 7791 - _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 6927 - _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 6998 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 7000 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 7118 - _EXECUTEPLANRESPONSE_METRICS._serialized_start = 7121 - _EXECUTEPLANRESPONSE_METRICS._serialized_end = 7638 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 7216 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 7548 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 7425 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 7548 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 7550 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 7638 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7640 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7756 - _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7758 - _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7774 - _KEYVALUE._serialized_start = 7793 - _KEYVALUE._serialized_end = 7858 - _CONFIGREQUEST._serialized_start = 7861 - _CONFIGREQUEST._serialized_end = 9020 - _CONFIGREQUEST_OPERATION._serialized_start = 8169 - _CONFIGREQUEST_OPERATION._serialized_end = 8667 - _CONFIGREQUEST_SET._serialized_start = 8669 - _CONFIGREQUEST_SET._serialized_end = 8721 - _CONFIGREQUEST_GET._serialized_start = 8723 - _CONFIGREQUEST_GET._serialized_end = 8748 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 8750 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 8813 - _CONFIGREQUEST_GETOPTION._serialized_start = 8815 - _CONFIGREQUEST_GETOPTION._serialized_end = 8846 - _CONFIGREQUEST_GETALL._serialized_start = 8848 - _CONFIGREQUEST_GETALL._serialized_end = 8896 - _CONFIGREQUEST_UNSET._serialized_start = 8898 - _CONFIGREQUEST_UNSET._serialized_end = 8925 - _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 8927 - _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 8961 - _CONFIGRESPONSE._serialized_start = 9023 - _CONFIGRESPONSE._serialized_end = 9198 - _ADDARTIFACTSREQUEST._serialized_start = 9201 - _ADDARTIFACTSREQUEST._serialized_end = 10203 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 9676 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 9729 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 9731 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 9842 - _ADDARTIFACTSREQUEST_BATCH._serialized_start = 9844 - _ADDARTIFACTSREQUEST_BATCH._serialized_end = 9937 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 9940 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 10133 - _ADDARTIFACTSRESPONSE._serialized_start = 10206 - _ADDARTIFACTSRESPONSE._serialized_end = 10478 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 10397 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 10478 - _ARTIFACTSTATUSESREQUEST._serialized_start = 10481 - _ARTIFACTSTATUSESREQUEST._serialized_end = 10807 - _ARTIFACTSTATUSESRESPONSE._serialized_start = 10810 - _ARTIFACTSTATUSESRESPONSE._serialized_end = 11162 - _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 11005 - 
_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 11120 - _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 11122 - _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 11162 - _INTERRUPTREQUEST._serialized_start = 11165 - _INTERRUPTREQUEST._serialized_end = 11768 - _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 11568 - _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 11696 - _INTERRUPTRESPONSE._serialized_start = 11771 - _INTERRUPTRESPONSE._serialized_end = 11915 - _REATTACHOPTIONS._serialized_start = 11917 - _REATTACHOPTIONS._serialized_end = 11970 - _REATTACHEXECUTEREQUEST._serialized_start = 11973 - _REATTACHEXECUTEREQUEST._serialized_end = 12379 - _RELEASEEXECUTEREQUEST._serialized_start = 12382 - _RELEASEEXECUTEREQUEST._serialized_end = 12967 - _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 12836 - _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 12848 - _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 12850 - _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 12897 - _RELEASEEXECUTERESPONSE._serialized_start = 12970 - _RELEASEEXECUTERESPONSE._serialized_end = 13135 - _RELEASESESSIONREQUEST._serialized_start = 13138 - _RELEASESESSIONREQUEST._serialized_end = 13309 - _RELEASESESSIONRESPONSE._serialized_start = 13311 - _RELEASESESSIONRESPONSE._serialized_end = 13419 - _FETCHERRORDETAILSREQUEST._serialized_start = 13422 - _FETCHERRORDETAILSREQUEST._serialized_end = 13754 - _FETCHERRORDETAILSRESPONSE._serialized_start = 13757 - _FETCHERRORDETAILSRESPONSE._serialized_end = 15312 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 13986 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 14160 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 14163 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 14531 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 14494 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 14531 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 14534 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 14943 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 14845 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 14913 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 14946 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 15293 - _SPARKCONNECTSERVICE._serialized_start = 15315 - _SPARKCONNECTSERVICE._serialized_end = 16261 + _EXECUTEPLANRESPONSE._serialized_end = 8230 + _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 7030 + _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 7101 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 7103 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 7221 + _EXECUTEPLANRESPONSE_METRICS._serialized_start = 7224 + _EXECUTEPLANRESPONSE_METRICS._serialized_end = 7741 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 7319 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 7651 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 7528 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 7651 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 7653 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 7741 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7743 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7859 + 
_EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7861 + _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7877 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_start = 7880 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_end = 8213 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_start = 8036 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_end = 8213 + _KEYVALUE._serialized_start = 8232 + _KEYVALUE._serialized_end = 8297 + _CONFIGREQUEST._serialized_start = 8300 + _CONFIGREQUEST._serialized_end = 9459 + _CONFIGREQUEST_OPERATION._serialized_start = 8608 + _CONFIGREQUEST_OPERATION._serialized_end = 9106 + _CONFIGREQUEST_SET._serialized_start = 9108 + _CONFIGREQUEST_SET._serialized_end = 9160 + _CONFIGREQUEST_GET._serialized_start = 9162 + _CONFIGREQUEST_GET._serialized_end = 9187 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 9189 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 9252 + _CONFIGREQUEST_GETOPTION._serialized_start = 9254 + _CONFIGREQUEST_GETOPTION._serialized_end = 9285 + _CONFIGREQUEST_GETALL._serialized_start = 9287 + _CONFIGREQUEST_GETALL._serialized_end = 9335 + _CONFIGREQUEST_UNSET._serialized_start = 9337 + _CONFIGREQUEST_UNSET._serialized_end = 9364 + _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 9366 + _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 9400 + _CONFIGRESPONSE._serialized_start = 9462 + _CONFIGRESPONSE._serialized_end = 9637 + _ADDARTIFACTSREQUEST._serialized_start = 9640 + _ADDARTIFACTSREQUEST._serialized_end = 10642 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 10115 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 10168 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 10170 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 10281 + _ADDARTIFACTSREQUEST_BATCH._serialized_start = 10283 + _ADDARTIFACTSREQUEST_BATCH._serialized_end = 10376 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 10379 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 10572 + _ADDARTIFACTSRESPONSE._serialized_start = 10645 + _ADDARTIFACTSRESPONSE._serialized_end = 10917 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 10836 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 10917 + _ARTIFACTSTATUSESREQUEST._serialized_start = 10920 + _ARTIFACTSTATUSESREQUEST._serialized_end = 11246 + _ARTIFACTSTATUSESRESPONSE._serialized_start = 11249 + _ARTIFACTSTATUSESRESPONSE._serialized_end = 11601 + _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 11444 + _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 11559 + _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 11561 + _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 11601 + _INTERRUPTREQUEST._serialized_start = 11604 + _INTERRUPTREQUEST._serialized_end = 12207 + _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 12007 + _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 12135 + _INTERRUPTRESPONSE._serialized_start = 12210 + _INTERRUPTRESPONSE._serialized_end = 12354 + _REATTACHOPTIONS._serialized_start = 12356 + _REATTACHOPTIONS._serialized_end = 12409 + _REATTACHEXECUTEREQUEST._serialized_start = 12412 + _REATTACHEXECUTEREQUEST._serialized_end = 12818 + _RELEASEEXECUTEREQUEST._serialized_start = 12821 + _RELEASEEXECUTEREQUEST._serialized_end = 13406 + _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 13275 + _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 13287 + _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 13289 + 
_RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 13336 + _RELEASEEXECUTERESPONSE._serialized_start = 13409 + _RELEASEEXECUTERESPONSE._serialized_end = 13574 + _RELEASESESSIONREQUEST._serialized_start = 13577 + _RELEASESESSIONREQUEST._serialized_end = 13748 + _RELEASESESSIONRESPONSE._serialized_start = 13750 + _RELEASESESSIONRESPONSE._serialized_end = 13858 + _FETCHERRORDETAILSREQUEST._serialized_start = 13861 + _FETCHERRORDETAILSREQUEST._serialized_end = 14193 + _FETCHERRORDETAILSRESPONSE._serialized_start = 14196 + _FETCHERRORDETAILSRESPONSE._serialized_end = 15751 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 14425 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 14599 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 14602 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 14970 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 14933 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 14970 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 14973 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 15382 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 15284 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 15352 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 15385 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 15732 + _SPARKCONNECTSERVICE._serialized_start = 15754 + _SPARKCONNECTSERVICE._serialized_end = 16700 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 562977331952..d22502f8839d 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -1224,7 +1224,7 @@ global___ExecutePlanRequest = ExecutePlanRequest class ExecutePlanResponse(google.protobuf.message.Message): """The response of a query, can be one or more for each request. Responses belonging to the same input query, carry the same `session_id`. - Next ID: 16 + Next ID: 17 """ DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -1446,6 +1446,76 @@ class ExecutePlanResponse(google.protobuf.message.Message): self, ) -> None: ... + class ExecutionProgress(google.protobuf.message.Message): + """This message is used to communicate progress about the query progress during the execution.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class StageInfo(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STAGE_ID_FIELD_NUMBER: builtins.int + NUM_TASKS_FIELD_NUMBER: builtins.int + NUM_COMPLETED_TASKS_FIELD_NUMBER: builtins.int + INPUT_BYTES_READ_FIELD_NUMBER: builtins.int + DONE_FIELD_NUMBER: builtins.int + stage_id: builtins.int + num_tasks: builtins.int + num_completed_tasks: builtins.int + input_bytes_read: builtins.int + done: builtins.bool + def __init__( + self, + *, + stage_id: builtins.int = ..., + num_tasks: builtins.int = ..., + num_completed_tasks: builtins.int = ..., + input_bytes_read: builtins.int = ..., + done: builtins.bool = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "done", + b"done", + "input_bytes_read", + b"input_bytes_read", + "num_completed_tasks", + b"num_completed_tasks", + "num_tasks", + b"num_tasks", + "stage_id", + b"stage_id", + ], + ) -> None: ... 
+ + STAGES_FIELD_NUMBER: builtins.int + NUM_INFLIGHT_TASKS_FIELD_NUMBER: builtins.int + @property + def stages( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___ExecutePlanResponse.ExecutionProgress.StageInfo + ]: + """Captures the progress of each individual stage.""" + num_inflight_tasks: builtins.int + """Captures the currently in progress tasks.""" + def __init__( + self, + *, + stages: collections.abc.Iterable[ + global___ExecutePlanResponse.ExecutionProgress.StageInfo + ] + | None = ..., + num_inflight_tasks: builtins.int = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "num_inflight_tasks", b"num_inflight_tasks", "stages", b"stages" + ], + ) -> None: ... + SESSION_ID_FIELD_NUMBER: builtins.int SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int OPERATION_ID_FIELD_NUMBER: builtins.int @@ -1459,6 +1529,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): STREAMING_QUERY_LISTENER_EVENTS_RESULT_FIELD_NUMBER: builtins.int RESULT_COMPLETE_FIELD_NUMBER: builtins.int CREATE_RESOURCE_PROFILE_COMMAND_RESULT_FIELD_NUMBER: builtins.int + EXECUTION_PROGRESS_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int METRICS_FIELD_NUMBER: builtins.int OBSERVED_METRICS_FIELD_NUMBER: builtins.int @@ -1517,6 +1588,9 @@ class ExecutePlanResponse(google.protobuf.message.Message): ) -> pyspark.sql.connect.proto.commands_pb2.CreateResourceProfileCommandResult: """Response for command that creates ResourceProfile.""" @property + def execution_progress(self) -> global___ExecutePlanResponse.ExecutionProgress: + """(Optional) Intermediate query progress reports.""" + @property def extension(self) -> google.protobuf.any_pb2.Any: """Support arbitrary result objects.""" @property @@ -1556,6 +1630,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): result_complete: global___ExecutePlanResponse.ResultComplete | None = ..., create_resource_profile_command_result: pyspark.sql.connect.proto.commands_pb2.CreateResourceProfileCommandResult | None = ..., + execution_progress: global___ExecutePlanResponse.ExecutionProgress | None = ..., extension: google.protobuf.any_pb2.Any | None = ..., metrics: global___ExecutePlanResponse.Metrics | None = ..., observed_metrics: collections.abc.Iterable[global___ExecutePlanResponse.ObservedMetrics] @@ -1569,6 +1644,8 @@ class ExecutePlanResponse(google.protobuf.message.Message): b"arrow_batch", "create_resource_profile_command_result", b"create_resource_profile_command_result", + "execution_progress", + b"execution_progress", "extension", b"extension", "get_resources_command_result", @@ -1600,6 +1677,8 @@ class ExecutePlanResponse(google.protobuf.message.Message): b"arrow_batch", "create_resource_profile_command_result", b"create_resource_profile_command_result", + "execution_progress", + b"execution_progress", "extension", b"extension", "get_resources_command_result", @@ -1647,6 +1726,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): "streaming_query_listener_events_result", "result_complete", "create_resource_profile_command_result", + "execution_progress", "extension", ] | None diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 13cad30bbff9..b19c420c3833 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -102,6 +102,7 @@ if TYPE_CHECKING: from pyspark.sql.connect.catalog import Catalog from pyspark.sql.connect.udf import UDFRegistration from 
pyspark.sql.connect.udtf import UDTFRegistration + from pyspark.sql.connect.shell.progress import ProgressHandler from pyspark.sql.connect.datasource import DataSourceRegistration @@ -325,6 +326,21 @@ class SparkSession: readStream.__doc__ = PySparkSession.readStream.__doc__ + def registerProgressHandler(self, handler: "ProgressHandler") -> None: + self._client.register_progress_handler(handler) + + registerProgressHandler.__doc__ = PySparkSession.registerProgressHandler.__doc__ + + def removeProgressHandler(self, handler: "ProgressHandler") -> None: + self._client.remove_progress_handler(handler) + + removeProgressHandler.__doc__ = PySparkSession.removeProgressHandler.__doc__ + + def clearProgressHandlers(self) -> None: + self._client.clear_progress_handlers() + + clearProgressHandlers.__doc__ = PySparkSession.clearProgressHandlers.__doc__ + def _inferSchemaFromList( self, data: Iterable[Any], names: Optional[List[str]] = None ) -> StructType: diff --git a/python/pyspark/sql/connect/shell/__init__.py b/python/pyspark/sql/connect/shell/__init__.py new file mode 100644 index 000000000000..b99733bffa0a --- /dev/null +++ b/python/pyspark/sql/connect/shell/__init__.py @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Helpers for integration with the IPython Shell""" + +import os + +PROGRESS_BAR_ENABLED = "SPARK_CONNECT_PROGRESS_BAR_ENABLED" + + +def progress_bar_enabled() -> bool: + return os.getenv(PROGRESS_BAR_ENABLED, "0") == "1" diff --git a/python/pyspark/sql/connect/shell/progress.py b/python/pyspark/sql/connect/shell/progress.py new file mode 100644 index 000000000000..8a8064c29cdc --- /dev/null +++ b/python/pyspark/sql/connect/shell/progress.py @@ -0,0 +1,187 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
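For context on the toggle introduced in `shell/__init__.py` above: `progress_bar_enabled()` consults only the `SPARK_CONNECT_PROGRESS_BAR_ENABLED` environment variable, at call time. A minimal sketch of that contract:

```
# The progress bar toggle reads SPARK_CONNECT_PROGRESS_BAR_ENABLED when
# called; any value other than the string "1" means disabled.
import os

from pyspark.sql.connect.shell import progress_bar_enabled

os.environ["SPARK_CONNECT_PROGRESS_BAR_ENABLED"] = "1"
assert progress_bar_enabled()

os.environ["SPARK_CONNECT_PROGRESS_BAR_ENABLED"] = "true"
assert not progress_bar_enabled()
```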
+# + +"""Implementation of a progress bar that is displayed while a query is running.""" +import abc +from dataclasses import dataclass +import time +import sys +import typing +from typing import Iterable, Any + +from pyspark.sql.connect.proto import ExecutePlanResponse + +try: + from IPython.utils.terminal import get_terminal_size +except ImportError: + + def get_terminal_size(defaultx: Any = None, defaulty: Any = None) -> Any: + return (80, 25) + + +from pyspark.sql.connect.shell import progress_bar_enabled + + +@dataclass +class StageInfo: + stage_id: int + num_tasks: int + num_completed_tasks: int + num_bytes_read: int + done: bool + + +class ProgressHandler(abc.ABC): + @abc.abstractmethod + def __call__( + self, + stages: typing.Optional[Iterable[StageInfo]], + inflight_tasks: int, + done: bool, + ) -> None: + pass + + +def from_proto( + proto: ExecutePlanResponse.ExecutionProgress, +) -> typing.Tuple[Iterable[StageInfo], int]: + result = [] + for stage in proto.stages: + result.append( + StageInfo( + stage_id=stage.stage_id, + num_tasks=stage.num_tasks, + num_completed_tasks=stage.num_completed_tasks, + num_bytes_read=stage.input_bytes_read, + done=stage.done, + ) + ) + return (result, proto.num_inflight_tasks) + + +class Progress: + """This is a small helper class to visualize a textual progress bar. + he interface is very simple and assumes that nothing else prints to the + standard output.""" + + SI_BYTE_SIZES = (1 << 60, 1 << 50, 1 << 40, 1 << 30, 1 << 20, 1 << 10, 1) + SI_BYTE_SUFFIXES = ("EiB", "PiB", "TiB", "GiB", "MiB", "KiB", "B") + + def __init__( + self, + char: str = "*", + min_width: int = 80, + output: typing.IO = sys.stdout, + enabled: bool = False, + handlers: Iterable[ProgressHandler] = [], + ) -> None: + """ + Constructs a new Progress bar. The progress bar is typically used in + the blocking query execution path to process the execution progress + methods from the server. + + Parameters + ---------- + char : str + The Default character to be used for printing the bar. + min_width : numeric + The minimum width of the progress bar + output : file + The output device to write the progress bar to. + enabled : bool + Whether the progress bar printing should be enabled or not. + handlers : list of ProgressHandler + A list of handlers that will be called when the progress bar is updated. + """ + self._ticks = 0 + self._tick = 0 + x, y = get_terminal_size() + self._min_width = min_width + self._char = char + self._width = max(min(min_width, x), self._min_width) + self._max_printed = 0 + self._started = time.time() + self._enabled = enabled or progress_bar_enabled() + self._bytes_read = 0 + self._out = output + self._running = 0 + self._handlers = handlers + self._stages: Iterable[StageInfo] = [] + + def _notify(self, done: bool = False) -> None: + for handler in self._handlers: + handler( + stages=self._stages, + inflight_tasks=self._running, + done=done, + ) + + def update_ticks(self, stages: Iterable[StageInfo], inflight_tasks: int) -> None: + """This method is called from the execution to update the progress bar with a new total + tick counter and the current position. This is necessary in case new stages get added with + new tasks and so the total task number will be updated as well. + + Parameters + ---------- + stages : list + A list of StageInfo objects reporting progress in each stage. + inflight_tasks : int + The number of tasks that are currently running. 
+ """ + total_tasks = sum(map(lambda x: x.num_tasks, stages)) + completed_tasks = sum(map(lambda x: x.num_completed_tasks, stages)) + if total_tasks > 0 and completed_tasks != self._tick: + self._ticks = total_tasks + self._tick = completed_tasks + self._bytes_read = sum(map(lambda x: x.num_bytes_read, stages)) + if self._tick > 0: + self.output() + self._running = inflight_tasks + self._stages = stages + self._notify(False) + + def finish(self) -> None: + """Clear the last line. Called when the processing is done.""" + self._notify(True) + if self._enabled: + print("\r" + " " * self._max_printed, end="", flush=True, file=self._out) + print("\r", end="", flush=True, file=self._out) + + def output(self) -> None: + """Writes the progress bar out.""" + if self._enabled: + val = int((self._tick / float(self._ticks)) * self._width) + bar = self._char * val + "-" * (self._width - val) + percent_complete = (self._tick / self._ticks) * 100 + elapsed = int(time.time() - self._started) + scanned = self._bytes_to_string(self._bytes_read) + running = self._running + buffer = ( + f"\r[{bar}] {percent_complete:.2f}% Complete " + f"({running} Tasks running, {elapsed}s, Scanned {scanned})" + ) + self._max_printed = max(len(buffer), self._max_printed) + print(buffer, end="", flush=True, file=self._out) + + @staticmethod + def _bytes_to_string(size: int) -> str: + """Helper method to convert a numeric bytes value into a human-readable representation""" + i = 0 + while i < len(Progress.SI_BYTE_SIZES) - 1 and size < 2 * Progress.SI_BYTE_SIZES[i]: + i += 1 + result = float(size) / Progress.SI_BYTE_SIZES[i] + return f"{result:.1f} {Progress.SI_BYTE_SUFFIXES[i]}" diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index f065a106bbf2..0cc2d7d3f13e 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -81,6 +81,7 @@ if TYPE_CHECKING: # Running MyPy type checks will always require pandas and # other dependencies so importing here is fine. from pyspark.sql.connect.client import SparkConnectClient + from pyspark.sql.connect.shell.progress import ProgressHandler try: import memory_profiler # noqa: F401 @@ -2029,6 +2030,61 @@ class SparkSession(SparkConversionMixin): addArtifact = addArtifacts + def registerProgressHandler(self, handler: "ProgressHandler") -> None: + """ + Register a progress handler to be called when a progress update is received from the server. + + .. versionadded:: 4.0 + + Parameters + ---------- + handler : ProgressHandler + A callable that follows the ProgressHandler interface. This handler will be called + on every progress update. + + Examples + -------- + + >>> def progress_handler(stages, inflight_tasks, done): + ... print(f"{len(stages)} Stages known, Done: {done}") + >>> spark.registerProgressHandler(progress_handler) + >>> res = spark.range(10).repartition(1).collect() + 3 Stages known, Done: False + 3 Stages known, Done: True + >>> spark.clearProgressHandlers() + """ + raise PySparkRuntimeError( + error_class="ONLY_SUPPORTED_WITH_SPARK_CONNECT", + message_parameters={"feature": "SparkSession.registerProgressHandler"}, + ) + + def removeProgressHandler(self, handler: "ProgressHandler") -> None: + """ + Remove a progress handler that was previously registered. + + .. versionadded:: 4.0 + + Parameters + ---------- + handler : ProgressHandler + The handler to remove if present in the list of progress handlers. 
+ """ + raise PySparkRuntimeError( + error_class="ONLY_SUPPORTED_WITH_SPARK_CONNECT", + message_parameters={"feature": "SparkSession.removeProgressHandler"}, + ) + + def clearProgressHandlers(self) -> None: + """ + Clear all registered progress handlers. + + .. versionadded:: 4.0 + """ + raise PySparkRuntimeError( + error_class="ONLY_SUPPORTED_WITH_SPARK_CONNECT", + message_parameters={"feature": "SparkSession.clearProgressHandlers"}, + ) + def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None: """ Copy file from local to cloud storage file system. @@ -2194,6 +2250,11 @@ def _test() -> None: os.chdir(os.environ["SPARK_HOME"]) + # Disable Doc Tests for Spark Connect only functions: + pyspark.sql.session.SparkSession.registerProgressHandler.__doc__ = None + pyspark.sql.session.SparkSession.removeProgressHandler.__doc__ = None + pyspark.sql.session.SparkSession.clearProgressHandlers.__doc__ = None + globs = pyspark.sql.session.__dict__.copy() globs["spark"] = ( SparkSession.builder.master("local[4]").appName("sql.session tests").getOrCreate() diff --git a/python/pyspark/sql/tests/connect/shell/__init__.py b/python/pyspark/sql/tests/connect/shell/__init__.py new file mode 100644 index 000000000000..cce3acad34a4 --- /dev/null +++ b/python/pyspark/sql/tests/connect/shell/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/sql/tests/connect/shell/test_progress.py b/python/pyspark/sql/tests/connect/shell/test_progress.py new file mode 100644 index 000000000000..7d99a699eefa --- /dev/null +++ b/python/pyspark/sql/tests/connect/shell/test_progress.py @@ -0,0 +1,111 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from io import StringIO +import unittest +from typing import Iterable + +from pyspark.testing.connectutils import ( + should_test_connect, + connect_requirement_message, +) +from pyspark.testing.utils import PySparkErrorTestUtils + +if should_test_connect: + from pyspark.sql.connect.shell.progress import Progress, StageInfo + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class ProgressBarTest(unittest.TestCase, PySparkErrorTestUtils): + def test_simple_progress(self): + stages = [StageInfo(0, 100, 50, 999, False)] + buffer = StringIO() + p = Progress(output=buffer, enabled=True) + p.update_ticks(stages, 10) + val = buffer.getvalue() + self.assertIn("50.00%", val, "Current progress is 50%") + self.assertIn("****", val, "Should use the default char to print.") + self.assertIn("Scanned 999.0 B", val, "Should contain the bytes scanned metric.") + self.assertFalse(val.endswith("\r"), "Line should not be empty") + p.finish() + val = buffer.getvalue() + self.assertTrue(val.endswith("\r"), "Line should be empty") + + def test_configure_char(self): + stages = [StageInfo(0, 100, 50, 999, False)] + buffer = StringIO() + p = Progress(char="+", output=buffer, enabled=True) + p.update_ticks(stages, 10) + val = buffer.getvalue() + self.assertIn("++++++", val, "Updating the char works.") + + def test_disabled_does_not_print(self): + stages = [StageInfo(0, 100, 50, 999, False)] + buffer = StringIO() + p = Progress(char="+", output=buffer, enabled=False) + p.update_ticks(stages, 10) + stages = [StageInfo(0, 100, 51, 999, False)] + p.update_ticks(stages, 10) + val = buffer.getvalue() + self.assertEqual(0, len(val), "If the printing is disabled, don't print.") + + def test_finish_progress(self): + stages = [StageInfo(0, 100, 50, 999, False)] + buffer = StringIO() + p = Progress(char="+", output=buffer, enabled=True) + p.update_ticks(stages, 10) + p.finish() + self.assertTrue(buffer.getvalue().endswith("\r"), "Last line should be empty") + + def test_progress_handler(self): + stages = [StageInfo(0, 0, 0, 0, False)] + + handler_called = 0 + done_called = False + + def handler(stages: Iterable[StageInfo], inflight_tasks: int, done: bool): + nonlocal handler_called, done_called + handler_called = 1 + self.assertEqual(100, sum(map(lambda x: x.num_tasks, stages))) + self.assertEqual(50, sum(map(lambda x: x.num_completed_tasks, stages))) + self.assertEqual(999, sum(map(lambda x: x.num_bytes_read, stages))) + self.assertEqual(10, inflight_tasks) + done_called = done + + buffer = StringIO() + p = Progress(char="+", output=buffer, enabled=True, handlers=[handler]) + p.update_ticks(stages, 1) + stages = [StageInfo(0, 100, 50, 999, False)] + p.update_ticks(stages, 10) + self.assertIn("++++++", buffer.getvalue(), "Updating the char works.") + self.assertEqual(1, handler_called, "Handler should be called.") + self.assertFalse(done_called, "Before finish, done should be False") + p.finish() + self.assertTrue(buffer.getvalue().endswith("\r"), "Last line should be empty") + self.assertTrue(done_called, "After finish, done should be True") + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.shell.test_progress import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py 
index bebe2cfc2923..b73a56340984 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -58,6 +58,27 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
     def tearDown(self):
         self.spark.stop()

+    def test_progress_handler(self):
+        handler_called = []
+
+        def handler(**kwargs):
+            nonlocal handler_called
+            handler_called.append(kwargs)
+
+        self.spark.registerProgressHandler(handler)
+        self.spark.sql("select 1").collect()
+        self.assertGreaterEqual(len(handler_called), 1)
+
+        handler_called = []
+        self.spark.removeProgressHandler(handler)
+        self.spark.sql("select 1").collect()
+        self.assertEqual(len(handler_called), 0)
+
+        self.spark.registerProgressHandler(handler)
+        self.spark.clearProgressHandlers()
+        self.spark.sql("select 1").collect()
+        self.assertEqual(len(handler_called), 0)
+
     def _check_no_active_session_error(self, e: PySparkException):
         self.check_error(exception=e, error_class="NO_ACTIVE_SESSION", message_parameters=dict())
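For completeness, an end-to-end sketch of the session-level API exercised by the test above; `sc://localhost` is a placeholder endpoint:

```
# Register a handler on a Spark Connect session, run a query, then
# deregister; the handler receives keyword arguments from the client.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()


def handler(stages, inflight_tasks, done):
    total = sum(s.num_tasks for s in stages)
    print(f"{total} total tasks, {inflight_tasks} in flight, done={done}")


spark.registerProgressHandler(handler)
spark.range(100).repartition(10).collect()
spark.removeProgressHandler(handler)
```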