This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f6999df0c7f0 [SPARK-47081][CONNECT] Support Query Execution Progress
f6999df0c7f0 is described below

commit f6999df0c7f0bb18778b29ebdbe9f7d40899808a
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Thu Apr 4 12:59:56 2024 +0800

    [SPARK-47081][CONNECT] Support Query Execution Progress
    
    ### What changes were proposed in this pull request?
    This patch adds a new mechanism to push query execution progress for batch
queries. We add a new response message type and periodically push query
progress to the client. The client can consume this data to, for example,
display a progress bar.
    
    This patch adds support for displaying a progress bar in the PySpark shell 
when started with Spark Connect.
    
    The proto message is defined as follows:
    
    ```
      // This message is used to communicate the progress of the query during execution.
      message ExecutionProgress {
        // Captures the progress of each individual stage.
        repeated StageInfo stages = 1;
    
        // Captures the currently in progress tasks.
        int64 num_inflight_tasks = 2;
    
        message StageInfo {
          int64 stage_id = 1;
          int64 num_tasks = 2;
          int64 num_completed_tasks = 3;
          int64 input_bytes_read = 4;
          bool done = 5;
        }
      }
    ```
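
    As a rough sketch (not part of this patch), a client receiving these messages
could fold the per-stage counters into an overall completion ratio; the field
access below assumes the generated Python proto classes:

    ```
    # Hypothetical helper, not included in this patch: aggregate an
    # ExecutionProgress message into a single completion ratio.
    def completion_ratio(progress):
        total = sum(s.num_tasks for s in progress.stages)
        done = sum(s.num_completed_tasks for s in progress.stages)
        return done / total if total > 0 else 0.0
    ```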
    
    Clients can simply ignore the messages or consume them. In addition, this
patch adds the ability to register a progress-tracking callback on the
SparkSession.
    
    ```
    handler = lambda **kwargs: print(kwargs)
    spark.register_progress_handler(handler)
    spark.range(100).collect()
    spark.remove_progress_handler(handler)
    ```
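
    The handler above is just a lambda; as a hedged variant under the same
assumptions, a named handler can collect every update for inspection after the
query finishes:

    ```
    # Sketch only: the handler receives whatever keyword arguments the client
    # passes, exactly as the lambda example above does.
    updates = []

    def collect_progress(**kwargs):
        updates.append(kwargs)

    spark.register_progress_handler(collect_progress)
    spark.range(100).collect()
    spark.remove_progress_handler(collect_progress)
    print(f"received {len(updates)} progress updates")
    ```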
    
    #### Example 1
    ![progress_medium_query_multi_stage mp4](https://github.com/apache/spark/assets/3421/5eff1ec4-def2-4d39-8a75-13a6af784c99)
    
    #### Example 2
    ![progress_bar mp4](https://github.com/apache/spark/assets/3421/20638511-2da4-4bd6-83f2-da3b9f500bde)
    
    ### Why are the changes needed?
    Usability and Experience
    
    ### Does this PR introduce _any_ user-facing change?
    When the user opens the PySpark shell in Spark Connect mode, it will show the
progress bar by default.
    
    ### How was this patch tested?
    Added new tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45150 from grundprinzip/SPARK-47081.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/SparkSessionE2ESuite.scala    |  10 +
 .../src/main/protobuf/spark/connect/base.proto     |  22 ++-
 .../spark/sql/connect/client/SparkResult.scala     |  39 ++++
 .../apache/spark/sql/connect/config/Connect.scala  |   8 +
 .../ConnectProgressExecutionListener.scala         | 191 +++++++++++++++++++
 .../execution/ExecuteGrpcResponseSender.scala      |  51 ++++-
 .../execution/ExecuteResponseObserver.scala        |  11 +-
 .../connect/execution/ExecuteThreadRunner.scala    |   5 +-
 .../sql/connect/service/SparkConnectService.scala  |   5 +
 .../ConnectProgressExecutionListenerSuite.scala    | 156 ++++++++++++++++
 .../org/apache/spark/deploy/SparkSubmit.scala      |   7 +-
 dev/sparktestsupport/modules.py                    |   1 +
 .../source/reference/pyspark.sql/spark_session.rst |  13 +-
 python/pyspark/shell.py                            |  21 ++-
 python/pyspark/sql/connect/client/core.py          |  55 +++++-
 python/pyspark/sql/connect/proto/base_pb2.py       | 208 +++++++++++----------
 python/pyspark/sql/connect/proto/base_pb2.pyi      |  82 +++++++-
 python/pyspark/sql/connect/session.py              |  16 ++
 python/pyspark/sql/connect/shell/__init__.py       |  26 +++
 python/pyspark/sql/connect/shell/progress.py       | 187 ++++++++++++++++++
 python/pyspark/sql/session.py                      |  61 ++++++
 python/pyspark/sql/tests/connect/shell/__init__.py |  16 ++
 .../sql/tests/connect/shell/test_progress.py       | 111 +++++++++++
 .../sql/tests/connect/test_connect_session.py      |  21 +++
 24 files changed, 1204 insertions(+), 119 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
index e4cbcf620d15..b967245d90c2 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
@@ -229,6 +229,16 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
     assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
   }
 
+  test("progress is available for the spark result") {
+    val result = spark
+      .range(10000)
+      .repartition(1000)
+      .collectResult()
+    assert(result.length == 10000)
+    assert(result.progress.stages.map(_.numTasks).sum > 100)
+    assert(result.progress.stages.map(_.completedTasks).sum > 100)
+  }
+
   test("interrupt operation") {
     val session = spark
     import session.implicits._
diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/base.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 9a9121d84f76..49a33d3419b6 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -333,7 +333,7 @@ message ExecutePlanRequest {
 
 // The response of a query, can be one or more for each request. Responses 
belonging to the
 // same input query, carry the same `session_id`.
-// Next ID: 16
+// Next ID: 17
 message ExecutePlanResponse {
   string session_id = 1;
   // Server-side generated idempotency key that the client can use to assert 
that the server side
@@ -378,6 +378,9 @@ message ExecutePlanResponse {
     // Response for command that creates ResourceProfile.
     CreateResourceProfileCommandResult create_resource_profile_command_result 
= 17;
 
+    // (Optional) Intermediate query progress reports.
+    ExecutionProgress execution_progress = 18;
+
     // Support arbitrary result objects.
     google.protobuf.Any extension = 999;
   }
@@ -438,6 +441,23 @@ message ExecutePlanResponse {
     // the execution is complete. If the server sends onComplete without 
sending a ResultComplete,
     // it means that there is more, and the client should use ReattachExecute 
RPC to continue.
   }
+
+  // This message is used to communicate the progress of the query during execution.
+  message ExecutionProgress {
+    // Captures the progress of each individual stage.
+    repeated StageInfo stages = 1;
+
+    // Captures the currently in progress tasks.
+    int64 num_inflight_tasks = 2;
+
+    message StageInfo {
+      int64 stage_id = 1;
+      int64 num_tasks = 2;
+      int64 num_completed_tasks = 3;
+      int64 input_bytes_read = 4;
+      bool done = 5;
+    }
+  }
 }
 
 // The key-value pair for the config request and response.
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 7a7c6a2d6c92..93d1075aea02 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -20,6 +20,7 @@ import java.lang.ref.Cleaner
 import java.util.Objects
 
 import scala.collection.mutable
+import scala.jdk.CollectionConverters._
 
 import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
@@ -40,6 +41,38 @@ private[sql] class SparkResult[T](
     timeZoneId: String)
     extends AutoCloseable { self =>
 
+  case class StageInfo(
+      stageId: Long,
+      numTasks: Long,
+      completedTasks: Long = 0,
+      inputBytesRead: Long = 0,
+      completed: Boolean = false)
+
+  object StageInfo {
+    def apply(stageInfo: 
proto.ExecutePlanResponse.ExecutionProgress.StageInfo): StageInfo = {
+      StageInfo(
+        stageInfo.getStageId,
+        stageInfo.getNumTasks,
+        stageInfo.getNumCompletedTasks,
+        stageInfo.getInputBytesRead,
+        stageInfo.getDone)
+    }
+  }
+
+  object Progress {
+    def apply(progress: proto.ExecutePlanResponse.ExecutionProgress): Progress 
= {
+      Progress(
+        progress.getStagesList.asScala.map(StageInfo(_)).toSeq,
+        progress.getNumInflightTasks)
+    }
+  }
+
+  /**
+   * Progress of the query execution. This information can be accessed from 
the iterator.
+   */
+  case class Progress(stages: Seq[StageInfo], inflight: Long)
+
+  var progress: Progress = new Progress(Seq.empty, 0)
   private[this] var opId: String = _
   private[this] var numRecords: Int = 0
   private[this] var structType: StructType = _
@@ -97,6 +130,12 @@ private[sql] class SparkResult[T](
       }
       stop |= stopOnOperationId
 
+      // Update the execution status. This information can now be accessed 
directly from
+      // the iterator.
+      if (response.hasExecutionProgress) {
+        progress = Progress(response.getExecutionProgress)
+      }
+
       if (response.hasSchema) {
         // The original schema should arrive before ArrowBatches.
         structType =
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 39bf1a630af6..6ba100af1bb9 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -265,4 +265,12 @@ object Connect {
       .version("4.0.0")
       .bytesConf(ByteUnit.BYTE)
       .createWithDefault(1024)
+
+  val CONNECT_PROGRESS_REPORT_INTERVAL =
+    buildConf("spark.connect.progress.reportInterval")
+      .doc("The interval at which the progress of a query is reported to the client." +
+        " If set to a negative value, the progress reports will be disabled.")
+      .version("4.0.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("2s")
 }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
new file mode 100644
index 000000000000..954956363505
--- /dev/null
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.execution
+
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
+
+import org.apache.spark.connect.proto.ExecutePlanResponse
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, 
SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, 
SparkListenerTaskStart}
+
+/**
+ * A listener that tracks the execution of jobs and stages for a given set of 
tags. This is used
+ * to track the progress of a job that is being executed through the connect 
API.
+ *
+ * The listener is instantiated once for the SparkConnectService and then used 
to track all the
+ * current query executions.
+ */
+private[connect] class ConnectProgressExecutionListener extends SparkListener 
with Logging {
+
+  /**
+   * A tracker for a given tag. This is used to track the progress of an operation that is being
+   * executed through the connect API.
+   */
+  class ExecutionTracker(val tag: String) {
+
+    class StageInfo(
+        val stageId: Int,
+        var numTasks: Int,
+        var completedTasks: Int = 0,
+        var inputBytesRead: Long = 0,
+        var completed: Boolean = false) {
+
+      val lock = new Object
+      def update(i: StageInfo => Unit): Unit = {
+        lock.synchronized {
+          i(this)
+        }
+      }
+
+      def toProto(): ExecutePlanResponse.ExecutionProgress.StageInfo = {
+        ExecutePlanResponse.ExecutionProgress.StageInfo
+          .newBuilder()
+          .setStageId(stageId)
+          .setNumTasks(numTasks)
+          .setNumCompletedTasks(completedTasks)
+          .setInputBytesRead(inputBytesRead)
+          .setDone(completed)
+          .build()
+      }
+    }
+
+    // The set of jobs that are being tracked by this tracker. We always only 
add to this list
+    // but never remove. This is to avoid concurrency issues.
+    private[ConnectProgressExecutionListener] var jobs: Set[Int] = Set()
+    // The set of stages that are being tracked by this tracker. We always 
only add to this list
+    // but never remove. This is to avoid concurrency issues.
+    private[ConnectProgressExecutionListener] var stages: Map[Int, StageInfo] 
= Map.empty
+    // The tracker is marked as dirty if it has new progress to report.
+    private[ConnectProgressExecutionListener] val dirty = new 
AtomicBoolean(false)
+    // Tracks all currently running tasks for a particular tracker.
+    private[ConnectProgressExecutionListener] val inFlightTasks = new 
AtomicInteger(0)
+
+    /**
+     * Yield the current state of the tracker if it is dirty. A consumer of 
the tracker can
+     * provide a callback that will be called with the current state of the 
tracker if the tracker
+     * has new progress to report.
+     *
+     * If the tracker was marked as dirty, the state is reset after.
+     */
+    def yieldWhenDirty(thunk: (Seq[StageInfo], Long) => Unit): Unit = {
+      if (dirty.get()) {
+        thunk(stages.values.toSeq, inFlightTasks.get())
+        dirty.set(false)
+      }
+    }
+
+    /**
+     * Add a job to the tracker. This will add the job to the list of jobs 
that are being tracked
+     */
+    def addJob(job: SparkListenerJobStart): Unit = synchronized {
+      jobs = jobs + job.jobId
+      job.stageInfos.foreach { stage =>
+        stages = stages + (stage.stageId -> new StageInfo(stage.stageId, 
stage.numTasks))
+      }
+      dirty.set(true)
+    }
+
+    def jobCount(): Int = {
+      jobs.size
+    }
+
+    def stageCount(): Int = {
+      stages.size
+    }
+  }
+
+  val trackedTags = collection.concurrent.TrieMap[String, ExecutionTracker]()
+
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+    val tags = jobStart.properties.getProperty("spark.job.tags")
+    if (tags != null) {
+      val thisJobTags = tags.split(",").map(_.trim).toSet
+      thisJobTags.foreach { tag =>
+        trackedTags.get(tag).foreach { tracker =>
+          tracker.addJob(jobStart)
+        }
+      }
+    }
+  }
+
+  override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+    // Check if the task belongs to a job that we are tracking.
+    trackedTags.foreach({ case (_, tracker) =>
+      if (tracker.stages.contains(taskStart.stageId)) {
+        tracker.inFlightTasks.incrementAndGet()
+        tracker.dirty.set(true)
+      }
+    })
+  }
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    // Check if the task belongs to a job that we are tracking.
+    trackedTags.foreach({ case (_, tracker) =>
+      if (tracker.stages.contains(taskEnd.stageId)) {
+        tracker.stages.get(taskEnd.stageId).foreach { stage =>
+          stage.update { i =>
+            i.completedTasks += 1
+            i.inputBytesRead += taskEnd.taskMetrics.inputMetrics.bytesRead
+          }
+        }
+        // This should never become negative, simply reset to zero if it does.
+        tracker.inFlightTasks.decrementAndGet()
+        if (tracker.inFlightTasks.get() < 0) {
+          tracker.inFlightTasks.set(0)
+        }
+        tracker.dirty.set(true)
+      }
+    })
+  }
+
+  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): 
Unit = {
+    trackedTags.foreach({ case (_, tracker) =>
+      if (tracker.stages.contains(stageCompleted.stageInfo.stageId)) {
+        tracker.stages(stageCompleted.stageInfo.stageId).update { stage =>
+          stage.completed = true
+        }
+        tracker.dirty.set(true)
+      }
+    })
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+    trackedTags.foreach({ case (_, tracker) =>
+      if (tracker.jobs.contains(jobEnd.jobId)) {
+        tracker.dirty.set(true)
+      }
+    })
+  }
+
+  def tryGetTracker(tag: String): Option[ExecutionTracker] = {
+    trackedTags.get(tag)
+  }
+
+  def registerJobTag(tag: String): Unit = {
+    trackedTags += tag -> new ExecutionTracker(tag)
+  }
+
+  def removeJobTag(tag: String): Unit = {
+    trackedTags -= tag
+  }
+
+  def clearJobTags(): Unit = {
+    trackedTags.clear()
+  }
+
+}
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
index c9ceef969e29..a9444862b3aa 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
@@ -17,13 +17,16 @@
 
 package org.apache.spark.sql.connect.execution
 
+import scala.jdk.CollectionConverters._
+
 import com.google.protobuf.Message
 import io.grpc.stub.{ServerCallStreamObserver, StreamObserver}
 
 import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.connect.proto.ExecutePlanResponse
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.common.ProtoUtils
-import 
org.apache.spark.sql.connect.config.Connect.{CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION,
 CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE}
+import 
org.apache.spark.sql.connect.config.Connect.{CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION,
 CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE, 
CONNECT_PROGRESS_REPORT_INTERVAL}
 import org.apache.spark.sql.connect.service.{ExecuteHolder, 
SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
 
@@ -131,6 +134,38 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
Message](
     }
   }
 
+  /**
+   * This method is called repeatedly during the query execution to enqueue a new message to be
+   * sent to the client about the current query progress. The message is not sent directly to the
+   * client, but rather enqueued in the response observer.
+   */
+  private def enqueueProgressMessage(): Unit = {
+    if 
(executeHolder.sessionHolder.session.conf.get(CONNECT_PROGRESS_REPORT_INTERVAL) 
> 0) {
+      SparkConnectService.executionListener.foreach { listener =>
+        // It is possible that the tracker is no longer available; in this
+        // case we simply ignore it and do not send any progress message. This avoids
+        // having to synchronize on the listener.
+        listener.tryGetTracker(executeHolder.jobTag).foreach { tracker =>
+          // Only send progress message if there is something new to report.
+          tracker.yieldWhenDirty { (stages, inflightTasks) =>
+            val response = ExecutePlanResponse
+              .newBuilder()
+              .setExecutionProgress(
+                ExecutePlanResponse.ExecutionProgress
+                  .newBuilder()
+                  .addAllStages(stages.map(_.toProto()).asJava)
+                  .setNumInflightTasks(inflightTasks))
+              .build()
+            // There is a special case when the response observer has already determined that the
+            // final message has been sent (and the stream will be closed) but we might want to
+            // send the progress message. In this case we ignore the result of the `onNext` call.
+            executeHolder.responseObserver.tryOnNext(response)
+          }
+        }
+      }
+    }
+  }
+
   /**
    * Attach to the executionObserver, consume responses from it, and send them 
to grpcObserver.
    *
@@ -173,6 +208,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
Message](
     var sentResponsesSize: Long = 0
 
     while (!finished) {
+      enqueueProgressMessage()
       var response: Option[CachedStreamResponse[T]] = None
 
       // Conditions for exiting the inner loop (and helpers to compute them):
@@ -201,9 +237,18 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
Message](
           // The state of interrupted, response and lastIndex are changed 
under executionObserver
           // monitor, and will notify upon state change.
           if (response.isEmpty) {
-            val timeout = Math.max(1, deadlineTimeMillis - 
System.currentTimeMillis())
+            // Wake up more frequently to send the progress updates.
+            val progressTimeout =
+              
executeHolder.sessionHolder.session.conf.get(CONNECT_PROGRESS_REPORT_INTERVAL)
+            // If the progress feature is disabled, wait for the deadline.
+            val timeout = if (progressTimeout > 0) {
+              progressTimeout
+            } else {
+              Math.max(1, deadlineTimeMillis - System.currentTimeMillis())
+            }
             logTrace(s"Wait for response to become available with 
timeout=$timeout ms.")
             executionObserver.responseLock.wait(timeout)
+            enqueueProgressMessage()
             logTrace(s"Reacquired executionObserver lock after waiting.")
             sleepEnd = System.nanoTime()
           }
@@ -228,6 +273,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
Message](
             s"waitingForResults=${consumeSleep}ns 
waitingForSend=${sendSleep}ns")
         throw new SparkSQLException(errorClass = 
"INVALID_CURSOR.DISCONNECTED", Map.empty)
       } else if (gotResponse) {
+        enqueueProgressMessage()
         // There is a response available to be sent.
         val sent = sendResponse(response.get, deadlineTimeMillis)
         if (sent) {
@@ -240,6 +286,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: 
Message](
           assert(deadlineLimitReached || interrupted)
         }
       } else if (streamFinished) {
+        enqueueProgressMessage()
         // Stream is finished and all responses have been sent
         logInfo(
           s"Stream finished for opId=${executeHolder.operationId}, " +
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
index a7877503f461..92c23c6165d2 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
@@ -107,9 +107,9 @@ private[connect] class ExecuteResponseObserver[T <: 
Message](val executeHolder:
     0
   }
 
-  def onNext(r: T): Unit = responseLock.synchronized {
+  def tryOnNext(r: T): Boolean = responseLock.synchronized {
     if (finalProducedIndex.nonEmpty) {
-      throw new IllegalStateException("Stream onNext can't be called after 
stream completed")
+      return false
     }
     lastProducedIndex += 1
     val processedResponse = setCommonResponseFields(r)
@@ -127,6 +127,13 @@ private[connect] class ExecuteResponseObserver[T <: 
Message](val executeHolder:
       s"Execution opId=${executeHolder.operationId} produced response " +
         s"responseId=${responseId} idx=$lastProducedIndex")
     responseLock.notifyAll()
+    true
+  }
+
+  def onNext(r: T): Unit = {
+    if (!tryOnNext(r)) {
+      throw new IllegalStateException("Stream onNext can't be called after 
stream completed")
+    }
   }
 
   def onError(t: Throwable): Unit = responseLock.synchronized {
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 41146e4ef688..56776819dac9 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -28,7 +28,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag}
+import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, 
SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
 import org.apache.spark.util.Utils
 
@@ -123,6 +123,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
           }
       } finally {
         
executeHolder.sessionHolder.session.sparkContext.removeJobTag(executeHolder.jobTag)
+        
SparkConnectService.executionListener.foreach(_.removeJobTag(executeHolder.jobTag))
         executeHolder.sparkSessionTags.foreach { tag =>
           executeHolder.sessionHolder.session.sparkContext.removeJobTag(
             ExecuteSessionTag(
@@ -158,6 +159,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
 
       // Set tag for query cancellation
       session.sparkContext.addJobTag(executeHolder.jobTag)
+      // Register the job for progress reports.
+      
SparkConnectService.executionListener.foreach(_.registerJobTag(executeHolder.jobTag))
       // Also set all user defined tags as Spark Job tags.
       executeHolder.sparkSessionTags.foreach { tag =>
         session.sparkContext.addJobTag(
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 9324e8e6c5f1..476254bc6e39 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -38,6 +38,7 @@ import 
org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import 
org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, 
CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, 
CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE}
+import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, 
SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
 import org.apache.spark.status.ElementTrackingStore
@@ -284,6 +285,7 @@ object SparkConnectService extends Logging {
 
   private[connect] var uiTab: Option[SparkConnectServerTab] = None
   private[connect] var listener: SparkConnectServerListener = _
+  private[connect] var executionListener: 
Option[ConnectProgressExecutionListener] = None
 
   // For testing purpose, it's package level private.
   private[connect] def localPort: Int = {
@@ -330,6 +332,9 @@ object SparkConnectService extends Logging {
     } else {
       None
     }
+    // Add the execution listener needed for query progress.
+    executionListener = Some(new ConnectProgressExecutionListener)
+    sc.addSparkListener(executionListener.get)
   }
 
   /**
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala
new file mode 100644
index 000000000000..43e978a18f1f
--- /dev/null
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.execution
+
+import java.util.Properties
+
+import org.mockito.Mockito.when
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.{SparkFunSuite, Success}
+import org.apache.spark.executor.{ExecutorMetrics, InputMetrics, TaskMetrics}
+import org.apache.spark.scheduler.{SparkListenerJobStart, 
SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart, 
StageInfo, TaskInfo}
+
+class ConnectProgressExecutionListenerSuite extends SparkFunSuite with 
MockitoSugar {
+
+  def mockStage(stageId: Int, numTasks: Int): StageInfo = {
+    val result = mock[StageInfo]
+    when(result.stageId).thenReturn(stageId)
+    when(result.numTasks).thenReturn(numTasks)
+    result
+  }
+
+  val testTag = "testTag"
+  val testStage1 = mockStage(1, 1)
+  val testStage2 = mockStage(2, 1)
+
+  val testStage1Task1 = mock[TaskInfo]
+  val testStage1Task1ExecutorMetrics = mock[ExecutorMetrics]
+  val testStage1Task1Metrics = mock[TaskMetrics]
+
+  val inputMetrics = mock[InputMetrics]
+  when(inputMetrics.bytesRead).thenReturn(500)
+  when(testStage1Task1Metrics.inputMetrics).thenReturn(inputMetrics)
+
+  val testStage2Task1 = mock[TaskInfo]
+//
+  val testProperties = new Properties()
+  testProperties.setProperty("spark.job.tags", s"otherTag,$testTag,anotherTag")
+
+  val testJobStart = SparkListenerJobStart(1, 1, Seq(testStage1, testStage2), 
testProperties)
+  val testTaskStart = SparkListenerTaskStart(1, 1, testStage1Task1)
+
+  test("onJobStart with no matching tags") {
+    val listener = new ConnectProgressExecutionListener
+    listener.onJobStart(testJobStart)
+    assert(listener.trackedTags.isEmpty)
+  }
+
+  test("onJobStart with a registered tag") {
+    val listener = new ConnectProgressExecutionListener
+    listener.registerJobTag(testTag)
+    assert(listener.trackedTags.size == 1)
+
+    // Trigger the event
+    listener.onJobStart(testJobStart)
+    val t = listener.trackedTags(testTag)
+
+    t.yieldWhenDirty((stages, inflight) => {
+      assert(stages.map(_.numTasks).sum == 2)
+      assert(stages.map(_.completedTasks).sum == 0)
+      assert(stages.size == 2)
+      assert(stages.map(_.inputBytesRead).sum == 0)
+      assert(inflight == 0)
+    })
+  }
+
+  test("taskDone") {
+    val listener = new ConnectProgressExecutionListener
+    listener.registerJobTag(testTag)
+    listener.onJobStart(testJobStart)
+
+    // Finish the tasks
+    val taskEnd = SparkListenerTaskEnd(
+      1,
+      1,
+      "taskType",
+      Success,
+      testStage1Task1,
+      testStage1Task1ExecutorMetrics,
+      testStage1Task1Metrics)
+
+    val t = listener.trackedTags(testTag)
+    var yielded = false
+    t.yieldWhenDirty { (stages, inflight) =>
+      assert(stages.map(_.numTasks).sum == 2)
+      assert(stages.map(_.completedTasks).sum == 0)
+      assert(stages.size == 2)
+      assert(
+        stages
+          .map(_.completed match {
+            case true => 1
+            case false => 0
+          })
+          .sum == 0)
+      yielded = true
+    }
+    assert(yielded, "Must be updated with results")
+
+    yielded = false
+    listener.onTaskEnd(taskEnd)
+    t.yieldWhenDirty { (stages, inflight) =>
+      assert(stages.map(_.numTasks).sum == 2)
+      assert(stages.map(_.completedTasks).sum == 1)
+      assert(stages.size == 2)
+      assert(stages.map(_.inputBytesRead).sum == 500)
+      assert(
+        stages
+          .map(_.completed match {
+            case true => 1
+            case false => 0
+          })
+          .sum == 0)
+      yielded = true
+    }
+    assert(yielded, "Must be updated with results")
+    yielded = false
+    t.yieldWhenDirty { (stages, inflight) =>
+      yielded = true
+    }
+    assert(!yielded, "Must not update if not dirty")
+
+    val stageEnd = SparkListenerStageCompleted(testStage1)
+    listener.onStageCompleted(stageEnd)
+    t.yieldWhenDirty { (stages, inflight) =>
+      assert(stages.map(_.numTasks).sum == 2)
+      assert(stages.map(_.completedTasks).sum == 1)
+      assert(stages.size == 2)
+      assert(stages.map(_.inputBytesRead).sum == 500)
+      assert(
+        stages
+          .map(_.completed match {
+            case true => 1
+            case false => 0
+          })
+          .sum == 1)
+      yielded = true
+    }
+    assert(yielded, "Must be updated with results")
+  }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala 
b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 9ab394741a82..c0df74f8d0cc 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -744,8 +744,11 @@ private[spark] class SparkSubmit extends Logging {
       }
     }
 
-    // In case of shells, spark.ui.showConsoleProgress can be true by default 
or by user.
-    if (isShell(args.primaryResource) && 
!sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) {
+    // In case of shells, spark.ui.showConsoleProgress can be true by default or by user, except
+    // when Spark Connect is in local mode, because Spark Connect supports its own progress
+    // reporting.
+    if (isShell(args.primaryResource) && 
!sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS) &&
+        !sparkConf.contains("spark.local.connect")) {
       sparkConf.set(UI_SHOW_CONSOLE_PROGRESS, true)
     }
 
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 6b087436c687..d3ffa79ebe68 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1060,6 +1060,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg",
         "pyspark.sql.tests.connect.test_parity_pandas_udf_window",
         "pyspark.sql.tests.connect.test_resources",
+        "pyspark.sql.tests.connect.shell.test_progress",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, 
pandas, and pyarrow and
diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst 
b/python/docs/source/reference/pyspark.sql/spark_session.rst
index ea71249e292e..4e679da59c16 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -78,12 +78,15 @@ Spark Connect Only
 
     SparkSession.addArtifact
     SparkSession.addArtifacts
-    SparkSession.copyFromLocalToFs
+    SparkSession.addTag
+    SparkSession.clearProgressHandlers
+    SparkSession.clearTags
     SparkSession.client
+    SparkSession.copyFromLocalToFs
+    SparkSession.getTags
     SparkSession.interruptAll
-    SparkSession.interruptTag
     SparkSession.interruptOperation
-    SparkSession.addTag
+    SparkSession.interruptTag
+    SparkSession.registerProgressHandler
+    SparkSession.removeProgressHandler
     SparkSession.removeTag
-    SparkSession.getTags
-    SparkSession.clearTags
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index f705f0edd8fe..12ff86ecc9ff 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -45,11 +45,30 @@ if getattr(builtins, "__IPYTHON__", False):
         if parent_dir in sys.path:
             sys.path.remove(parent_dir)
 
-
 if is_remote():
     try:
         # Creates pyspark.sql.connect.SparkSession.
         spark = SparkSession.builder.getOrCreate()
+
+        from pyspark.sql.connect.shell import PROGRESS_BAR_ENABLED
+
+        # Check if the progress bar needs to be disabled.
+        if PROGRESS_BAR_ENABLED not in os.environ:
+            os.environ[PROGRESS_BAR_ENABLED] = "1"
+        else:
+            val = os.getenv(PROGRESS_BAR_ENABLED, "false")
+            if val.lower().strip() == "false":
+                os.environ[PROGRESS_BAR_ENABLED] = "0"
+            elif val.lower().strip() == "true":
+                os.environ[PROGRESS_BAR_ENABLED] = "1"
+
+        val = os.environ[PROGRESS_BAR_ENABLED]
+        if val not in ("1", "0"):
+            raise ValueError(
+                f"Environment variable '{PROGRESS_BAR_ENABLED}' must "
+                f"be set to either 1 or 0, found: {val}"
+            )
+
     except Exception:
         import sys
         import traceback
diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index b3807d80f6c9..17b5d99aba94 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -93,6 +93,7 @@ from pyspark.sql.types import DataType, StructType, 
TimestampType, _has_type
 from pyspark.util import PythonEvalType
 from pyspark.storagelevel import StorageLevel
 from pyspark.errors import PySparkValueError, PySparkAssertionError, 
PySparkNotImplementedError
+from pyspark.sql.connect.shell.progress import Progress, ProgressHandler, 
from_proto
 
 if TYPE_CHECKING:
     from google.rpc.error_details_pb2 import ErrorInfo
@@ -694,6 +695,37 @@ class SparkConnectClient(object):
 
         self._profiler_collector = ConnectProfilerCollector()
 
+        self._progress_handlers: List[ProgressHandler] = []
+
+    def register_progress_handler(self, handler: ProgressHandler) -> None:
+        """
+        Register a progress handler to be called when a progress message is 
received.
+
+        Parameters
+        ----------
+        handler : ProgressHandler
+          The callable that will be called with the progress information.
+
+        """
+        if handler in self._progress_handlers:
+            return
+        self._progress_handlers.append(handler)
+
+    def clear_progress_handlers(self) -> None:
+        self._progress_handlers.clear()
+
+    def remove_progress_handler(self, handler: ProgressHandler) -> None:
+        """
+        Remove a progress handler from the list of registered handlers.
+
+        Parameters
+        ----------
+        handler : ProgressHandler
+          The callable to remove from the list of progress handlers.
+
+        """
+        self._progress_handlers.remove(handler)
+
     def _retrying(self) -> "Retrying":
         return Retrying(self._retry_policies)
 
@@ -1213,7 +1245,10 @@ class SparkConnectClient(object):
             self._handle_error(error)
 
     def _execute_and_fetch_as_iterator(
-        self, req: pb2.ExecutePlanRequest, observations: Dict[str, Observation]
+        self,
+        req: pb2.ExecutePlanRequest,
+        observations: Dict[str, Observation],
+        progress: Optional["Progress"] = None,
     ) -> Iterator[
         Union[
             "pa.RecordBatch",
@@ -1292,6 +1327,10 @@ class SparkConnectClient(object):
                 yield {"get_resources_command_result": resources}
             if b.HasField("extension"):
                 yield b.extension
+            if b.HasField("execution_progress"):
+                if progress:
+                    p = from_proto(b.execution_progress)
+                    progress.update_ticks(*p)
             if b.HasField("arrow_batch"):
                 logger.debug(
                     f"Received arrow batch rows={b.arrow_batch.row_count} "
@@ -1338,6 +1377,16 @@ class SparkConnectClient(object):
                     with attempt:
                         for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
                             yield from handle_response(b)
+        except KeyboardInterrupt:
+            logger.debug(f"Interrupt request received for 
operation={req.operation_id}")
+            try:
+                self.interrupt_operation(req.operation_id)
+            except Exception as e:
+                # Swallow all errors if aborted.
+                logger.debug(f"Caught an error during interrupt handling, 
silenced: {e}")
+                pass
+            if progress is not None:
+                progress.finish()
         except Exception as error:
             self._handle_error(error)
 
@@ -1361,7 +1410,8 @@ class SparkConnectClient(object):
         schema: Optional[StructType] = None
         properties: Dict[str, Any] = {}
 
-        for response in self._execute_and_fetch_as_iterator(req, observations):
+        progress = Progress(handlers=self._progress_handlers)
+        for response in self._execute_and_fetch_as_iterator(req, observations, 
progress=progress):
             if isinstance(response, StructType):
                 schema = response
             elif isinstance(response, pa.RecordBatch):
@@ -1379,6 +1429,7 @@ class SparkConnectClient(object):
                         "response": response,
                     },
                 )
+        progress.finish()
 
         if len(batches) > 0:
             if self_destruct:
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py 
b/python/pyspark/sql/connect/proto/base_pb2.py
index 2943057a99fc..b9f88aab3c26 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
 
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
 [...]
+    
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
 
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -120,105 +120,109 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 5196
     _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5361
     _EXECUTEPLANRESPONSE._serialized_start = 5440
-    _EXECUTEPLANRESPONSE._serialized_end = 7791
-    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 6927
-    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 6998
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 7000
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 7118
-    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 7121
-    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 7638
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 7216
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 7548
-    
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start
 = 7425
-    
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end 
= 7548
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 7550
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 7638
-    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7640
-    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7756
-    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7758
-    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7774
-    _KEYVALUE._serialized_start = 7793
-    _KEYVALUE._serialized_end = 7858
-    _CONFIGREQUEST._serialized_start = 7861
-    _CONFIGREQUEST._serialized_end = 9020
-    _CONFIGREQUEST_OPERATION._serialized_start = 8169
-    _CONFIGREQUEST_OPERATION._serialized_end = 8667
-    _CONFIGREQUEST_SET._serialized_start = 8669
-    _CONFIGREQUEST_SET._serialized_end = 8721
-    _CONFIGREQUEST_GET._serialized_start = 8723
-    _CONFIGREQUEST_GET._serialized_end = 8748
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 8750
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 8813
-    _CONFIGREQUEST_GETOPTION._serialized_start = 8815
-    _CONFIGREQUEST_GETOPTION._serialized_end = 8846
-    _CONFIGREQUEST_GETALL._serialized_start = 8848
-    _CONFIGREQUEST_GETALL._serialized_end = 8896
-    _CONFIGREQUEST_UNSET._serialized_start = 8898
-    _CONFIGREQUEST_UNSET._serialized_end = 8925
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 8927
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 8961
-    _CONFIGRESPONSE._serialized_start = 9023
-    _CONFIGRESPONSE._serialized_end = 9198
-    _ADDARTIFACTSREQUEST._serialized_start = 9201
-    _ADDARTIFACTSREQUEST._serialized_end = 10203
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 9676
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 9729
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 9731
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 9842
-    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 9844
-    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 9937
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 9940
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 10133
-    _ADDARTIFACTSRESPONSE._serialized_start = 10206
-    _ADDARTIFACTSRESPONSE._serialized_end = 10478
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 10397
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 10478
-    _ARTIFACTSTATUSESREQUEST._serialized_start = 10481
-    _ARTIFACTSTATUSESREQUEST._serialized_end = 10807
-    _ARTIFACTSTATUSESRESPONSE._serialized_start = 10810
-    _ARTIFACTSTATUSESRESPONSE._serialized_end = 11162
-    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 11005
-    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 11120
-    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 11122
-    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 11162
-    _INTERRUPTREQUEST._serialized_start = 11165
-    _INTERRUPTREQUEST._serialized_end = 11768
-    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 11568
-    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 11696
-    _INTERRUPTRESPONSE._serialized_start = 11771
-    _INTERRUPTRESPONSE._serialized_end = 11915
-    _REATTACHOPTIONS._serialized_start = 11917
-    _REATTACHOPTIONS._serialized_end = 11970
-    _REATTACHEXECUTEREQUEST._serialized_start = 11973
-    _REATTACHEXECUTEREQUEST._serialized_end = 12379
-    _RELEASEEXECUTEREQUEST._serialized_start = 12382
-    _RELEASEEXECUTEREQUEST._serialized_end = 12967
-    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 12836
-    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 12848
-    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 12850
-    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 12897
-    _RELEASEEXECUTERESPONSE._serialized_start = 12970
-    _RELEASEEXECUTERESPONSE._serialized_end = 13135
-    _RELEASESESSIONREQUEST._serialized_start = 13138
-    _RELEASESESSIONREQUEST._serialized_end = 13309
-    _RELEASESESSIONRESPONSE._serialized_start = 13311
-    _RELEASESESSIONRESPONSE._serialized_end = 13419
-    _FETCHERRORDETAILSREQUEST._serialized_start = 13422
-    _FETCHERRORDETAILSREQUEST._serialized_end = 13754
-    _FETCHERRORDETAILSRESPONSE._serialized_start = 13757
-    _FETCHERRORDETAILSRESPONSE._serialized_end = 15312
-    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 13986
-    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 14160
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 14163
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 14531
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 
14494
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 14531
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 14534
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 14943
-    
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start
 = 14845
-    
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end
 = 14913
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 14946
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 15293
-    _SPARKCONNECTSERVICE._serialized_start = 15315
-    _SPARKCONNECTSERVICE._serialized_end = 16261
+    _EXECUTEPLANRESPONSE._serialized_end = 8230
+    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 7030
+    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 7101
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 7103
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 7221
+    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 7224
+    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 7741
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 7319
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 7651
+    
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start
 = 7528
+    
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end 
= 7651
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 7653
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 7741
+    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7743
+    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7859
+    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7861
+    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7877
+    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_start = 7880
+    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_end = 8213
+    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_start = 8036
+    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_end = 8213
+    _KEYVALUE._serialized_start = 8232
+    _KEYVALUE._serialized_end = 8297
+    _CONFIGREQUEST._serialized_start = 8300
+    _CONFIGREQUEST._serialized_end = 9459
+    _CONFIGREQUEST_OPERATION._serialized_start = 8608
+    _CONFIGREQUEST_OPERATION._serialized_end = 9106
+    _CONFIGREQUEST_SET._serialized_start = 9108
+    _CONFIGREQUEST_SET._serialized_end = 9160
+    _CONFIGREQUEST_GET._serialized_start = 9162
+    _CONFIGREQUEST_GET._serialized_end = 9187
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 9189
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 9252
+    _CONFIGREQUEST_GETOPTION._serialized_start = 9254
+    _CONFIGREQUEST_GETOPTION._serialized_end = 9285
+    _CONFIGREQUEST_GETALL._serialized_start = 9287
+    _CONFIGREQUEST_GETALL._serialized_end = 9335
+    _CONFIGREQUEST_UNSET._serialized_start = 9337
+    _CONFIGREQUEST_UNSET._serialized_end = 9364
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 9366
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 9400
+    _CONFIGRESPONSE._serialized_start = 9462
+    _CONFIGRESPONSE._serialized_end = 9637
+    _ADDARTIFACTSREQUEST._serialized_start = 9640
+    _ADDARTIFACTSREQUEST._serialized_end = 10642
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 10115
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 10168
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 10170
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 10281
+    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 10283
+    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 10376
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 10379
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 10572
+    _ADDARTIFACTSRESPONSE._serialized_start = 10645
+    _ADDARTIFACTSRESPONSE._serialized_end = 10917
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 10836
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 10917
+    _ARTIFACTSTATUSESREQUEST._serialized_start = 10920
+    _ARTIFACTSTATUSESREQUEST._serialized_end = 11246
+    _ARTIFACTSTATUSESRESPONSE._serialized_start = 11249
+    _ARTIFACTSTATUSESRESPONSE._serialized_end = 11601
+    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 11444
+    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 11559
+    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 11561
+    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 11601
+    _INTERRUPTREQUEST._serialized_start = 11604
+    _INTERRUPTREQUEST._serialized_end = 12207
+    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 12007
+    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 12135
+    _INTERRUPTRESPONSE._serialized_start = 12210
+    _INTERRUPTRESPONSE._serialized_end = 12354
+    _REATTACHOPTIONS._serialized_start = 12356
+    _REATTACHOPTIONS._serialized_end = 12409
+    _REATTACHEXECUTEREQUEST._serialized_start = 12412
+    _REATTACHEXECUTEREQUEST._serialized_end = 12818
+    _RELEASEEXECUTEREQUEST._serialized_start = 12821
+    _RELEASEEXECUTEREQUEST._serialized_end = 13406
+    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 13275
+    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 13287
+    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 13289
+    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 13336
+    _RELEASEEXECUTERESPONSE._serialized_start = 13409
+    _RELEASEEXECUTERESPONSE._serialized_end = 13574
+    _RELEASESESSIONREQUEST._serialized_start = 13577
+    _RELEASESESSIONREQUEST._serialized_end = 13748
+    _RELEASESESSIONRESPONSE._serialized_start = 13750
+    _RELEASESESSIONRESPONSE._serialized_end = 13858
+    _FETCHERRORDETAILSREQUEST._serialized_start = 13861
+    _FETCHERRORDETAILSREQUEST._serialized_end = 14193
+    _FETCHERRORDETAILSRESPONSE._serialized_start = 14196
+    _FETCHERRORDETAILSRESPONSE._serialized_end = 15751
+    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 14425
+    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 14599
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 14602
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 14970
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 
14933
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 14970
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 14973
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 15382
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 15284
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 15352
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 15385
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 15732
+    _SPARKCONNECTSERVICE._serialized_start = 15754
+    _SPARKCONNECTSERVICE._serialized_end = 16700
 # @@protoc_insertion_point(module_scope)
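As a hedged aside (not part of the patch): once the regenerated Python protos are on the path, the new `execution_progress` branch can be read like any other member of the response oneof. A minimal sketch, assuming the `pyspark.sql.connect.proto` import surface used elsewhere in this diff:

```
from pyspark.sql.connect.proto import ExecutePlanResponse

# Build a response carrying a progress update (normally produced by the server).
response = ExecutePlanResponse()
response.execution_progress.num_inflight_tasks = 4
stage = response.execution_progress.stages.add()
stage.stage_id = 0
stage.num_tasks = 100
stage.num_completed_tasks = 42

# Progress is one branch of the response oneof, so check presence before reading.
if response.HasField("execution_progress"):
    progress = response.execution_progress
    total = sum(s.num_tasks for s in progress.stages)
    completed = sum(s.num_completed_tasks for s in progress.stages)
    print(f"{completed}/{total} tasks, {progress.num_inflight_tasks} in flight")
```
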
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi 
b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 562977331952..d22502f8839d 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1224,7 +1224,7 @@ global___ExecutePlanRequest = ExecutePlanRequest
 class ExecutePlanResponse(google.protobuf.message.Message):
     """The response of a query, can be one or more for each request. Responses 
belonging to the
     same input query, carry the same `session_id`.
-    Next ID: 16
+    Next ID: 17
     """
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -1446,6 +1446,76 @@ class 
ExecutePlanResponse(google.protobuf.message.Message):
             self,
         ) -> None: ...
 
+    class ExecutionProgress(google.protobuf.message.Message):
+        """This message is used to communicate progress about the query 
progress during the execution."""
+
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        class StageInfo(google.protobuf.message.Message):
+            DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+            STAGE_ID_FIELD_NUMBER: builtins.int
+            NUM_TASKS_FIELD_NUMBER: builtins.int
+            NUM_COMPLETED_TASKS_FIELD_NUMBER: builtins.int
+            INPUT_BYTES_READ_FIELD_NUMBER: builtins.int
+            DONE_FIELD_NUMBER: builtins.int
+            stage_id: builtins.int
+            num_tasks: builtins.int
+            num_completed_tasks: builtins.int
+            input_bytes_read: builtins.int
+            done: builtins.bool
+            def __init__(
+                self,
+                *,
+                stage_id: builtins.int = ...,
+                num_tasks: builtins.int = ...,
+                num_completed_tasks: builtins.int = ...,
+                input_bytes_read: builtins.int = ...,
+                done: builtins.bool = ...,
+            ) -> None: ...
+            def ClearField(
+                self,
+                field_name: typing_extensions.Literal[
+                    "done",
+                    b"done",
+                    "input_bytes_read",
+                    b"input_bytes_read",
+                    "num_completed_tasks",
+                    b"num_completed_tasks",
+                    "num_tasks",
+                    b"num_tasks",
+                    "stage_id",
+                    b"stage_id",
+                ],
+            ) -> None: ...
+
+        STAGES_FIELD_NUMBER: builtins.int
+        NUM_INFLIGHT_TASKS_FIELD_NUMBER: builtins.int
+        @property
+        def stages(
+            self,
+        ) -> 
google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+            global___ExecutePlanResponse.ExecutionProgress.StageInfo
+        ]:
+            """Captures the progress of each individual stage."""
+        num_inflight_tasks: builtins.int
+        """Captures the currently in progress tasks."""
+        def __init__(
+            self,
+            *,
+            stages: collections.abc.Iterable[
+                global___ExecutePlanResponse.ExecutionProgress.StageInfo
+            ]
+            | None = ...,
+            num_inflight_tasks: builtins.int = ...,
+        ) -> None: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "num_inflight_tasks", b"num_inflight_tasks", "stages", 
b"stages"
+            ],
+        ) -> None: ...
+
     SESSION_ID_FIELD_NUMBER: builtins.int
     SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
     OPERATION_ID_FIELD_NUMBER: builtins.int
@@ -1459,6 +1529,7 @@ class 
ExecutePlanResponse(google.protobuf.message.Message):
     STREAMING_QUERY_LISTENER_EVENTS_RESULT_FIELD_NUMBER: builtins.int
     RESULT_COMPLETE_FIELD_NUMBER: builtins.int
     CREATE_RESOURCE_PROFILE_COMMAND_RESULT_FIELD_NUMBER: builtins.int
+    EXECUTION_PROGRESS_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     METRICS_FIELD_NUMBER: builtins.int
     OBSERVED_METRICS_FIELD_NUMBER: builtins.int
@@ -1517,6 +1588,9 @@ class 
ExecutePlanResponse(google.protobuf.message.Message):
     ) -> 
pyspark.sql.connect.proto.commands_pb2.CreateResourceProfileCommandResult:
         """Response for command that creates ResourceProfile."""
     @property
+    def execution_progress(self) -> 
global___ExecutePlanResponse.ExecutionProgress:
+        """(Optional) Intermediate query progress reports."""
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """Support arbitrary result objects."""
     @property
@@ -1556,6 +1630,7 @@ class 
ExecutePlanResponse(google.protobuf.message.Message):
         result_complete: global___ExecutePlanResponse.ResultComplete | None = 
...,
         create_resource_profile_command_result: 
pyspark.sql.connect.proto.commands_pb2.CreateResourceProfileCommandResult
         | None = ...,
+        execution_progress: global___ExecutePlanResponse.ExecutionProgress | 
None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
         metrics: global___ExecutePlanResponse.Metrics | None = ...,
         observed_metrics: 
collections.abc.Iterable[global___ExecutePlanResponse.ObservedMetrics]
@@ -1569,6 +1644,8 @@ class 
ExecutePlanResponse(google.protobuf.message.Message):
             b"arrow_batch",
             "create_resource_profile_command_result",
             b"create_resource_profile_command_result",
+            "execution_progress",
+            b"execution_progress",
             "extension",
             b"extension",
             "get_resources_command_result",
@@ -1600,6 +1677,8 @@ class 
ExecutePlanResponse(google.protobuf.message.Message):
             b"arrow_batch",
             "create_resource_profile_command_result",
             b"create_resource_profile_command_result",
+            "execution_progress",
+            b"execution_progress",
             "extension",
             b"extension",
             "get_resources_command_result",
@@ -1647,6 +1726,7 @@ class 
ExecutePlanResponse(google.protobuf.message.Message):
             "streaming_query_listener_events_result",
             "result_complete",
             "create_resource_profile_command_result",
+            "execution_progress",
             "extension",
         ]
         | None
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index 13cad30bbff9..b19c420c3833 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -102,6 +102,7 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.catalog import Catalog
     from pyspark.sql.connect.udf import UDFRegistration
     from pyspark.sql.connect.udtf import UDTFRegistration
+    from pyspark.sql.connect.shell.progress import ProgressHandler
     from pyspark.sql.connect.datasource import DataSourceRegistration
 
 
@@ -325,6 +326,21 @@ class SparkSession:
 
     readStream.__doc__ = PySparkSession.readStream.__doc__
 
+    def registerProgressHandler(self, handler: "ProgressHandler") -> None:
+        self._client.register_progress_handler(handler)
+
+    registerProgressHandler.__doc__ = 
PySparkSession.registerProgressHandler.__doc__
+
+    def removeProgressHandler(self, handler: "ProgressHandler") -> None:
+        self._client.remove_progress_handler(handler)
+
+    removeProgressHandler.__doc__ = 
PySparkSession.removeProgressHandler.__doc__
+
+    def clearProgressHandlers(self) -> None:
+        self._client.clear_progress_handlers()
+
+    clearProgressHandlers.__doc__ = 
PySparkSession.clearProgressHandlers.__doc__
+
     def _inferSchemaFromList(
         self, data: Iterable[Any], names: Optional[List[str]] = None
     ) -> StructType:
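For illustration only, a minimal sketch of how these new session methods might be used. It assumes a reachable Connect server at the placeholder endpoint `sc://localhost`, and relies on the keyword arguments (`stages`, `inflight_tasks`, `done`) that the client-side progress helper added later in this patch passes to every handler:

```
from pyspark.sql import SparkSession

# Placeholder endpoint; any reachable Spark Connect server works.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

def on_progress(stages, inflight_tasks, done, **kwargs):
    # Invoked by the client on every progress update streamed from the server.
    completed = sum(s.num_completed_tasks for s in stages)
    total = sum(s.num_tasks for s in stages)
    print(f"{completed}/{total} tasks completed, done={done}")

spark.registerProgressHandler(on_progress)
spark.range(10).repartition(2).collect()
spark.removeProgressHandler(on_progress)
```
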
diff --git a/python/pyspark/sql/connect/shell/__init__.py 
b/python/pyspark/sql/connect/shell/__init__.py
new file mode 100644
index 000000000000..b99733bffa0a
--- /dev/null
+++ b/python/pyspark/sql/connect/shell/__init__.py
@@ -0,0 +1,26 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Helpers for integration with the IPython Shell"""
+
+import os
+
+PROGRESS_BAR_ENABLED = "SPARK_CONNECT_PROGRESS_BAR_ENABLED"
+
+
+def progress_bar_enabled() -> bool:
+    return os.getenv(PROGRESS_BAR_ENABLED, "0") == "1"
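A small hedged sketch of toggling this flag programmatically; the variable name is taken from the module above, and outside the PySpark shell it simply changes what `progress_bar_enabled()` reports:

```
import os

# Hypothetical: force the bar on before the Connect shell or session is created.
os.environ["SPARK_CONNECT_PROGRESS_BAR_ENABLED"] = "1"

from pyspark.sql.connect.shell import progress_bar_enabled

assert progress_bar_enabled()  # True once the variable is set to "1"
```
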
diff --git a/python/pyspark/sql/connect/shell/progress.py 
b/python/pyspark/sql/connect/shell/progress.py
new file mode 100644
index 000000000000..8a8064c29cdc
--- /dev/null
+++ b/python/pyspark/sql/connect/shell/progress.py
@@ -0,0 +1,187 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Implementation of a progress bar that is displayed while a query is 
running."""
+import abc
+from dataclasses import dataclass
+import time
+import sys
+import typing
+from typing import Iterable, Any
+
+from pyspark.sql.connect.proto import ExecutePlanResponse
+
+try:
+    from IPython.utils.terminal import get_terminal_size
+except ImportError:
+
+    def get_terminal_size(defaultx: Any = None, defaulty: Any = None) -> Any:
+        return (80, 25)
+
+
+from pyspark.sql.connect.shell import progress_bar_enabled
+
+
+@dataclass
+class StageInfo:
+    stage_id: int
+    num_tasks: int
+    num_completed_tasks: int
+    num_bytes_read: int
+    done: bool
+
+
+class ProgressHandler(abc.ABC):
+    @abc.abstractmethod
+    def __call__(
+        self,
+        stages: typing.Optional[Iterable[StageInfo]],
+        inflight_tasks: int,
+        done: bool,
+    ) -> None:
+        pass
+
+
+def from_proto(
+    proto: ExecutePlanResponse.ExecutionProgress,
+) -> typing.Tuple[Iterable[StageInfo], int]:
+    result = []
+    for stage in proto.stages:
+        result.append(
+            StageInfo(
+                stage_id=stage.stage_id,
+                num_tasks=stage.num_tasks,
+                num_completed_tasks=stage.num_completed_tasks,
+                num_bytes_read=stage.input_bytes_read,
+                done=stage.done,
+            )
+        )
+    return (result, proto.num_inflight_tasks)
+
+
+class Progress:
+    """This is a small helper class to visualize a textual progress bar.
+    The interface is very simple and assumes that nothing else prints to the
+    standard output."""
+
+    SI_BYTE_SIZES = (1 << 60, 1 << 50, 1 << 40, 1 << 30, 1 << 20, 1 << 10, 1)
+    SI_BYTE_SUFFIXES = ("EiB", "PiB", "TiB", "GiB", "MiB", "KiB", "B")
+
+    def __init__(
+        self,
+        char: str = "*",
+        min_width: int = 80,
+        output: typing.IO = sys.stdout,
+        enabled: bool = False,
+        handlers: Iterable[ProgressHandler] = [],
+    ) -> None:
+        """
+        Constructs a new Progress bar. The progress bar is typically used in
+        the blocking query execution path to process the execution progress
+        methods from the server.
+
+        Parameters
+        ----------
+        char : str
+          The default character used for printing the bar.
+        min_width : numeric
+          The minimum width of the progress bar
+        output : file
+          The output device to write the progress bar to.
+        enabled : bool
+          Whether the progress bar printing should be enabled or not.
+        handlers : list of ProgressHandler
+          A list of handlers that will be called when the progress bar is 
updated.
+        """
+        self._ticks = 0
+        self._tick = 0
+        x, y = get_terminal_size()
+        self._min_width = min_width
+        self._char = char
+        self._width = max(min(min_width, x), self._min_width)
+        self._max_printed = 0
+        self._started = time.time()
+        self._enabled = enabled or progress_bar_enabled()
+        self._bytes_read = 0
+        self._out = output
+        self._running = 0
+        self._handlers = handlers
+        self._stages: Iterable[StageInfo] = []
+
+    def _notify(self, done: bool = False) -> None:
+        for handler in self._handlers:
+            handler(
+                stages=self._stages,
+                inflight_tasks=self._running,
+                done=done,
+            )
+
+    def update_ticks(self, stages: Iterable[StageInfo], inflight_tasks: int) 
-> None:
+        """This method is called from the execution to update the progress bar 
with a new total
+        tick counter and the current position. This is necessary in case new 
stages get added with
+        new tasks and so the total task number will be updated as well.
+
+        Parameters
+        ----------
+        stages : list
+          A list of StageInfo objects reporting progress in each stage.
+        inflight_tasks : int
+          The number of tasks that are currently running.
+        """
+        total_tasks = sum(map(lambda x: x.num_tasks, stages))
+        completed_tasks = sum(map(lambda x: x.num_completed_tasks, stages))
+        if total_tasks > 0 and completed_tasks != self._tick:
+            self._ticks = total_tasks
+            self._tick = completed_tasks
+            self._bytes_read = sum(map(lambda x: x.num_bytes_read, stages))
+            if self._tick > 0:
+                self.output()
+            self._running = inflight_tasks
+            self._stages = stages
+            self._notify(False)
+
+    def finish(self) -> None:
+        """Clear the last line. Called when the processing is done."""
+        self._notify(True)
+        if self._enabled:
+            print("\r" + " " * self._max_printed, end="", flush=True, 
file=self._out)
+            print("\r", end="", flush=True, file=self._out)
+
+    def output(self) -> None:
+        """Writes the progress bar out."""
+        if self._enabled:
+            val = int((self._tick / float(self._ticks)) * self._width)
+            bar = self._char * val + "-" * (self._width - val)
+            percent_complete = (self._tick / self._ticks) * 100
+            elapsed = int(time.time() - self._started)
+            scanned = self._bytes_to_string(self._bytes_read)
+            running = self._running
+            buffer = (
+                f"\r[{bar}] {percent_complete:.2f}% Complete "
+                f"({running} Tasks running, {elapsed}s, Scanned {scanned})"
+            )
+            self._max_printed = max(len(buffer), self._max_printed)
+            print(buffer, end="", flush=True, file=self._out)
+
+    @staticmethod
+    def _bytes_to_string(size: int) -> str:
+        """Helper method to convert a numeric bytes value into a 
human-readable representation"""
+        i = 0
+        while i < len(Progress.SI_BYTE_SIZES) - 1 and size < 2 * 
Progress.SI_BYTE_SIZES[i]:
+            i += 1
+        result = float(size) / Progress.SI_BYTE_SIZES[i]
+        return f"{result:.1f} {Progress.SI_BYTE_SUFFIXES[i]}"
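To make the control flow concrete, here is a hedged sketch that drives the `Progress` helper directly, much like the new unit tests further down; output goes to an in-memory buffer so nothing is printed to a real terminal:

```
from io import StringIO
from pyspark.sql.connect.shell.progress import Progress, StageInfo

updates = []

def record(stages, inflight_tasks, done):
    # Handlers are invoked with keyword arguments from Progress._notify.
    updates.append((sum(s.num_completed_tasks for s in stages), done))

buffer = StringIO()
p = Progress(output=buffer, enabled=True, handlers=[record])
p.update_ticks(
    [StageInfo(stage_id=0, num_tasks=100, num_completed_tasks=50,
               num_bytes_read=1024, done=False)],
    inflight_tasks=10,
)
p.finish()

print(updates)                        # [(50, False), (50, True)]
print("50.00%" in buffer.getvalue())  # True
```
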
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index f065a106bbf2..0cc2d7d3f13e 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -81,6 +81,7 @@ if TYPE_CHECKING:
     # Running MyPy type checks will always require pandas and
     # other dependencies so importing here is fine.
     from pyspark.sql.connect.client import SparkConnectClient
+    from pyspark.sql.connect.shell.progress import ProgressHandler
 
 try:
     import memory_profiler  # noqa: F401
@@ -2029,6 +2030,61 @@ class SparkSession(SparkConversionMixin):
 
     addArtifact = addArtifacts
 
+    def registerProgressHandler(self, handler: "ProgressHandler") -> None:
+        """
+        Register a progress handler to be called when a progress update is 
received from the server.
+
+        .. versionadded:: 4.0
+
+        Parameters
+        ----------
+        handler : ProgressHandler
+          A callable that follows the ProgressHandler interface. This handler 
will be called
+          on every progress update.
+
+        Examples
+        --------
+
+        >>> def progress_handler(stages, inflight_tasks, done):
+        ...     print(f"{len(stages)} Stages known, Done: {done}")
+        >>> spark.registerProgressHandler(progress_handler)
+        >>> res = spark.range(10).repartition(1).collect()
+        3 Stages known, Done: False
+        3 Stages known, Done: True
+        >>> spark.clearProgressHandlers()
+        """
+        raise PySparkRuntimeError(
+            error_class="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
+            message_parameters={"feature": 
"SparkSession.registerProgressHandler"},
+        )
+
+    def removeProgressHandler(self, handler: "ProgressHandler") -> None:
+        """
+        Remove a progress handler that was previously registered.
+
+        .. versionadded:: 4.0
+
+        Parameters
+        ----------
+        handler : ProgressHandler
+          The handler to remove if present in the list of progress handlers.
+        """
+        raise PySparkRuntimeError(
+            error_class="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
+            message_parameters={"feature": 
"SparkSession.removeProgressHandler"},
+        )
+
+    def clearProgressHandlers(self) -> None:
+        """
+        Clear all registered progress handlers.
+
+        .. versionadded:: 4.0
+        """
+        raise PySparkRuntimeError(
+            error_class="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
+            message_parameters={"feature": 
"SparkSession.clearProgressHandlers"},
+        )
+
     def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
         """
         Copy file from local to cloud storage file system.
@@ -2194,6 +2250,11 @@ def _test() -> None:
 
     os.chdir(os.environ["SPARK_HOME"])
 
+    # Disable Doc Tests for Spark Connect only functions:
+    pyspark.sql.session.SparkSession.registerProgressHandler.__doc__ = None
+    pyspark.sql.session.SparkSession.removeProgressHandler.__doc__ = None
+    pyspark.sql.session.SparkSession.clearProgressHandlers.__doc__ = None
+
     globs = pyspark.sql.session.__dict__.copy()
     globs["spark"] = (
         SparkSession.builder.master("local[4]").appName("sql.session 
tests").getOrCreate()
diff --git a/python/pyspark/sql/tests/connect/shell/__init__.py 
b/python/pyspark/sql/tests/connect/shell/__init__.py
new file mode 100644
index 000000000000..cce3acad34a4
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/shell/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/sql/tests/connect/shell/test_progress.py 
b/python/pyspark/sql/tests/connect/shell/test_progress.py
new file mode 100644
index 000000000000..7d99a699eefa
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/shell/test_progress.py
@@ -0,0 +1,111 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from io import StringIO
+import unittest
+from typing import Iterable
+
+from pyspark.testing.connectutils import (
+    should_test_connect,
+    connect_requirement_message,
+)
+from pyspark.testing.utils import PySparkErrorTestUtils
+
+if should_test_connect:
+    from pyspark.sql.connect.shell.progress import Progress, StageInfo
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class ProgressBarTest(unittest.TestCase, PySparkErrorTestUtils):
+    def test_simple_progress(self):
+        stages = [StageInfo(0, 100, 50, 999, False)]
+        buffer = StringIO()
+        p = Progress(output=buffer, enabled=True)
+        p.update_ticks(stages, 10)
+        val = buffer.getvalue()
+        self.assertIn("50.00%", val, "Current progress is 50%")
+        self.assertIn("****", val, "Should use the default char to print.")
+        self.assertIn("Scanned 999.0 B", val, "Should contain the bytes 
scanned metric.")
+        self.assertFalse(val.endswith("\r"), "Line should not be empty")
+        p.finish()
+        val = buffer.getvalue()
+        self.assertTrue(val.endswith("\r"), "Line should be empty")
+
+    def test_configure_char(self):
+        stages = [StageInfo(0, 100, 50, 999, False)]
+        buffer = StringIO()
+        p = Progress(char="+", output=buffer, enabled=True)
+        p.update_ticks(stages, 10)
+        val = buffer.getvalue()
+        self.assertIn("++++++", val, "Updating the char works.")
+
+    def test_disabled_does_not_print(self):
+        stages = [StageInfo(0, 100, 50, 999, False)]
+        buffer = StringIO()
+        p = Progress(char="+", output=buffer, enabled=False)
+        p.update_ticks(stages, 10)
+        stages = [StageInfo(0, 100, 51, 999, False)]
+        p.update_ticks(stages, 10)
+        val = buffer.getvalue()
+        self.assertEqual(0, len(val), "If the printing is disabled, don't 
print.")
+
+    def test_finish_progress(self):
+        stages = [StageInfo(0, 100, 50, 999, False)]
+        buffer = StringIO()
+        p = Progress(char="+", output=buffer, enabled=True)
+        p.update_ticks(stages, 10)
+        p.finish()
+        self.assertTrue(buffer.getvalue().endswith("\r"), "Last line should be 
empty")
+
+    def test_progress_handler(self):
+        stages = [StageInfo(0, 0, 0, 0, False)]
+
+        handler_called = 0
+        done_called = False
+
+        def handler(stages: Iterable[StageInfo], inflight_tasks: int, done: 
bool):
+            nonlocal handler_called, done_called
+            handler_called = 1
+            self.assertEqual(100, sum(map(lambda x: x.num_tasks, stages)))
+            self.assertEqual(50, sum(map(lambda x: x.num_completed_tasks, 
stages)))
+            self.assertEqual(999, sum(map(lambda x: x.num_bytes_read, stages)))
+            self.assertEqual(10, inflight_tasks)
+            done_called = done
+
+        buffer = StringIO()
+        p = Progress(char="+", output=buffer, enabled=True, handlers=[handler])
+        p.update_ticks(stages, 1)
+        stages = [StageInfo(0, 100, 50, 999, False)]
+        p.update_ticks(stages, 10)
+        self.assertIn("++++++", buffer.getvalue(), "Updating the char works.")
+        self.assertEqual(1, handler_called, "Handler should be called.")
+        self.assertFalse(done_called, "Before finish, done should be False")
+        p.finish()
+        self.assertTrue(buffer.getvalue().endswith("\r"), "Last line should be 
empty")
+        self.assertTrue(done_called, "After finish, done should be True")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.shell.test_progress import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
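For completeness, a hedged sketch of running the new suite programmatically (for example from an interactive session) without `xmlrunner` installed:

```
import unittest

from pyspark.sql.tests.connect.shell.test_progress import ProgressBarTest

# Load and run just the progress-bar tests with the plain text runner.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(ProgressBarTest)
unittest.TextTestRunner(verbosity=2).run(suite)
```
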
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py 
b/python/pyspark/sql/tests/connect/test_connect_session.py
index bebe2cfc2923..b73a56340984 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -58,6 +58,27 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
     def tearDown(self):
         self.spark.stop()
 
+    def test_progress_handler(self):
+        handler_called = []
+
+        def handler(**kwargs):
+            nonlocal handler_called
+            handler_called.append(kwargs)
+
+        self.spark.registerProgressHandler(handler)
+        self.spark.sql("select 1").collect()
+        self.assertGreaterEqual(len(handler_called), 1)
+
+        handler_called = []
+        self.spark.removeProgressHandler(handler)
+        self.spark.sql("select 1").collect()
+        self.assertEqual(len(handler_called), 0)
+
+        self.spark.registerProgressHandler(handler)
+        self.spark.clearProgressHandlers()
+        self.spark.sql("select 1").collect()
+        self.assertEqual(len(handler_called), 0)
+
     def _check_no_active_session_error(self, e: PySparkException):
         self.check_error(exception=e, error_class="NO_ACTIVE_SESSION", 
message_parameters=dict())
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
