This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 4a3fd8f7e69e [SPARK-47081][CONNECT][FOLLOW] Improving the usability of the Progress Handler
4a3fd8f7e69e is described below

commit 4a3fd8f7e69e5d0cea52fae120348973bffbb738
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Tue Apr 9 16:58:47 2024 +0900

[SPARK-47081][CONNECT][FOLLOW] Improving the usability of the Progress Handler

### What changes were proposed in this pull request?

This patch improves the usability of the progress handler by making sure that an update to the client is sent on every wakeup interval from the server (and not only when a task is finished). The class managing the progress is now usable as a context manager and I've added the progress reporting to more RPC calls to the server.

In addition, it adds the operation ID to the progress handler notify message so that the callback can differentiate between multiple concurrent queries.

```python
def progress_handler(stages, inflight_tasks, operation_id):
    print(f"Operation {operation_id}: {inflight_tasks} inflight tasks")

spark.registerProgressHandler(progress_handler)
```

### Why are the changes needed?

Usability

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added Tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #45907 from grundprinzip/SPARK-47081_2.
Lead-authored-by: Martin Grund <martin.gr...@databricks.com> Co-authored-by: Martin Grund <grundprin...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../scala/org/apache/spark/sql/SparkSession.scala | 5 +- .../ConnectProgressExecutionListener.scala | 10 ++-- .../execution/ExecuteGrpcResponseSender.scala | 9 ++-- .../ConnectProgressExecutionListenerSuite.scala | 12 ++--- .../connect/planner/SparkConnectServiceSuite.scala | 39 +++++++++------- python/pyspark/sql/connect/client/core.py | 54 +++++++++++----------- python/pyspark/sql/connect/shell/progress.py | 37 ++++++++++++--- .../sql/tests/connect/shell/test_progress.py | 31 ++++++++++++- 8 files changed, 131 insertions(+), 66 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1e467a864442..5a2d9bc44c9f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -567,8 +567,9 @@ class SparkSession private[sql] ( private[sql] def execute(command: proto.Command): Seq[ExecutePlanResponse] = { val plan = proto.Plan.newBuilder().setCommand(command).build() - // .toSeq forces that the iterator is consumed and closed - client.execute(plan).toSeq + // .toSeq forces that the iterator is consumed and closed. On top, ignore all + // progress messages. 
+ client.execute(plan).filter(!_.hasExecutionProgress).toSeq } private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala index 954956363505..a1881765a416 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala @@ -82,10 +82,14 @@ private[connect] class ConnectProgressExecutionListener extends SparkListener wi * * If the tracker was marked as dirty, the state is reset after. */ - def yieldWhenDirty(thunk: (Seq[StageInfo], Long) => Unit): Unit = { - if (dirty.get()) { + def yieldWhenDirty(force: Boolean = false)(thunk: (Seq[StageInfo], Long) => Unit): Unit = { + if (force) { thunk(stages.values.toSeq, inFlightTasks.get()) - dirty.set(false) + } else { + if (dirty.get()) { + thunk(stages.values.toSeq, inFlightTasks.get()) + dirty.set(false) + } } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala index a9444862b3aa..4b95f38c6695 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala @@ -139,7 +139,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( * send to the client about the current query progress. The message is not directly send to the * client, but rather enqueued to in the response observer. 
*/ - private def enqueueProgressMessage(): Unit = { + private def enqueueProgressMessage(force: Boolean = false): Unit = { if (executeHolder.sessionHolder.session.conf.get(CONNECT_PROGRESS_REPORT_INTERVAL) > 0) { SparkConnectService.executionListener.foreach { listener => // It is possible, that the tracker is no longer available and in this @@ -147,7 +147,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( // having to synchronize on the listener. listener.tryGetTracker(executeHolder.jobTag).foreach { tracker => // Only send progress message if there is something new to report. - tracker.yieldWhenDirty { (stages, inflightTasks) => + tracker.yieldWhenDirty(force) { (stages, inflightTasks) => val response = ExecutePlanResponse .newBuilder() .setExecutionProgress( @@ -158,7 +158,8 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( .build() // There is a special case when the response observer has alreaady determined // that the final message is send (and the stream will be closed) but we might want - // to send the progress message. In this case we ignore the result of the `onNext` call. + // to send the progress message. In this case we ignore the result of the `onNext` + // call. 
executeHolder.responseObserver.tryOnNext(response) } } @@ -248,7 +249,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( } logTrace(s"Wait for response to become available with timeout=$timeout ms.") executionObserver.responseLock.wait(timeout) - enqueueProgressMessage() + enqueueProgressMessage(force = true) logTrace(s"Reacquired executionObserver lock after waiting.") sleepEnd = System.nanoTime() } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala index 43e978a18f1f..7c1b9362425d 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala @@ -70,13 +70,13 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu listener.onJobStart(testJobStart) val t = listener.trackedTags(testTag) - t.yieldWhenDirty((stages, inflight) => { + t.yieldWhenDirty() { (stages, inflight) => assert(stages.map(_.numTasks).sum == 2) assert(stages.map(_.completedTasks).sum == 0) assert(stages.size == 2) assert(stages.map(_.inputBytesRead).sum == 0) assert(inflight == 0) - }) + } } test("taskDone") { @@ -96,7 +96,7 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu val t = listener.trackedTags(testTag) var yielded = false - t.yieldWhenDirty { (stages, inflight) => + t.yieldWhenDirty() { (stages, inflight) => assert(stages.map(_.numTasks).sum == 2) assert(stages.map(_.completedTasks).sum == 0) assert(stages.size == 2) @@ -113,7 +113,7 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu yielded = false listener.onTaskEnd(taskEnd) - t.yieldWhenDirty { (stages, inflight) => + 
t.yieldWhenDirty() { (stages, inflight) => assert(stages.map(_.numTasks).sum == 2) assert(stages.map(_.completedTasks).sum == 1) assert(stages.size == 2) @@ -129,14 +129,14 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu } assert(yielded, "Must updated with results") yielded = false - t.yieldWhenDirty { (stages, inflight) => + t.yieldWhenDirty() { (stages, inflight) => yielded = true } assert(!yielded, "Must not update if not dirty") val stageEnd = SparkListenerStageCompleted(testStage1) listener.onStageCompleted(stageEnd) - t.yieldWhenDirty { (stages, inflight) => + t.yieldWhenDirty() { (stages, inflight) => assert(stages.map(_.numTasks).sum == 2) assert(stages.map(_.completedTasks).sum == 1) assert(stages.size == 2) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index dafcaa9e0225..63cebd452364 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -197,15 +197,16 @@ class SparkConnectServiceSuite // The current implementation is expected to be blocking. This is here to make sure it is. 
assert(done) - // 4 Partitions + Metrics - assert(responses.size == 6) + // 4 Partitions + Metrics + optional progress messages + val filteredResponses = responses.filter(!_.hasExecutionProgress) + assert(filteredResponses.size == 6) // Make sure the first response is schema only - val head = responses.head + val head = filteredResponses.head assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) // Make sure the last response is metrics only - val last = responses.last + val last = filteredResponses.last assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) val allocator = new RootAllocator() @@ -213,7 +214,7 @@ class SparkConnectServiceSuite // Check the 'data' batches var expectedId = 0L var previousEId = 0.0d - responses.tail.dropRight(1).foreach { response => + filteredResponses.tail.dropRight(1).foreach { response => assert(response.hasArrowBatch) val batch = response.getArrowBatch assert(batch.getData != null) @@ -298,14 +299,15 @@ class SparkConnectServiceSuite assert(done) // 1 Partitions + Metrics - assert(responses.size == 3) + val filteredResponses = responses.filter(!_.hasExecutionProgress) + assert(filteredResponses.size == 3) // Make sure the first response is schema only - val head = responses.head + val head = filteredResponses.head assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) // Make sure the last response is metrics only - val last = responses.last + val last = filteredResponses.last assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) } } @@ -353,12 +355,13 @@ class SparkConnectServiceSuite assert(done) // 1 schema + 1 metric + at least 2 data batches - assert(responses.size > 3) + val filteredResponses = responses.filter(!_.hasExecutionProgress) + assert(filteredResponses.size > 3) val allocator = new RootAllocator() // Check the 'data' batches - responses.tail.dropRight(1).foreach { response => + filteredResponses.tail.dropRight(1).foreach { response => assert(response.hasArrowBatch) val 
batch = response.getArrowBatch assert(batch.getData != null) @@ -533,15 +536,16 @@ class SparkConnectServiceSuite assert(done) // Result + Metrics - if (responses.size > 1) { - assert(responses.size == 2) + val filteredResponses = responses.filter(!_.hasExecutionProgress) + if (filteredResponses.size > 1) { + assert(filteredResponses.size == 2) // Make sure the first response result only - val head = responses.head + val head = filteredResponses.head assert(head.hasSqlCommandResult && !head.hasMetrics) // Make sure the last response is metrics only - val last = responses.last + val last = filteredResponses.last assert(last.hasMetrics && !last.hasSqlCommandResult) } } @@ -786,14 +790,15 @@ class SparkConnectServiceSuite // The current implementation is expected to be blocking. This is here to make sure it is. assert(done) - assert(responses.size == 7) + val filteredResponses = responses.filter(!_.hasExecutionProgress) + assert(filteredResponses.size == 7) // Make sure the first response is schema only - val head = responses.head + val head = filteredResponses.head assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) // Make sure the last response is observed metrics only - val last = responses.last + val last = filteredResponses.last assert(last.getObservedMetricsCount == 1 && !last.hasSchema && !last.hasArrowBatch) val observedMetricsList = last.getObservedMetricsList.asScala diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index b731960bbaf3..532d490d925e 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -897,11 +897,12 @@ class SparkConnectClient(object): logger.info(f"Executing plan {self._proto_to_string(plan)}") req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) - for response in self._execute_and_fetch_as_iterator(req, observations): - if isinstance(response, StructType): - yield response - elif isinstance(response, 
pa.RecordBatch): - yield pa.Table.from_batches([response]) + with Progress(handlers=self._progress_handlers, operation_id=req.operation_id) as progress: + for response in self._execute_and_fetch_as_iterator(req, observations, progress): + if isinstance(response, StructType): + yield response + elif isinstance(response, pa.RecordBatch): + yield pa.Table.from_batches([response]) def to_table( self, plan: pb2.Plan, observations: Dict[str, Observation] @@ -1331,7 +1332,7 @@ class SparkConnectClient(object): if b.HasField("execution_progress"): if progress: p = from_proto(b.execution_progress) - progress.update_ticks(*p) + progress.update_ticks(*p, operation_id=b.operation_id) if b.HasField("arrow_batch"): logger.debug( f"Received arrow batch rows={b.arrow_batch.row_count} " @@ -1411,26 +1412,27 @@ class SparkConnectClient(object): schema: Optional[StructType] = None properties: Dict[str, Any] = {} - progress = Progress(handlers=self._progress_handlers) - for response in self._execute_and_fetch_as_iterator(req, observations, progress=progress): - if isinstance(response, StructType): - schema = response - elif isinstance(response, pa.RecordBatch): - batches.append(response) - elif isinstance(response, PlanMetrics): - metrics.append(response) - elif isinstance(response, PlanObservedMetrics): - observed_metrics.append(response) - elif isinstance(response, dict): - properties.update(**response) - else: - raise PySparkValueError( - error_class="UNKNOWN_RESPONSE", - message_parameters={ - "response": response, - }, - ) - progress.finish() + with Progress(handlers=self._progress_handlers, operation_id=req.operation_id) as progress: + for response in self._execute_and_fetch_as_iterator( + req, observations, progress=progress + ): + if isinstance(response, StructType): + schema = response + elif isinstance(response, pa.RecordBatch): + batches.append(response) + elif isinstance(response, PlanMetrics): + metrics.append(response) + elif isinstance(response, PlanObservedMetrics): + 
observed_metrics.append(response) + elif isinstance(response, dict): + properties.update(**response) + else: + raise PySparkValueError( + error_class="UNKNOWN_RESPONSE", + message_parameters={ + "response": response, + }, + ) if len(batches) > 0: if self_destruct: diff --git a/python/pyspark/sql/connect/shell/progress.py b/python/pyspark/sql/connect/shell/progress.py index 8a8064c29cdc..52b34a924900 100644 --- a/python/pyspark/sql/connect/shell/progress.py +++ b/python/pyspark/sql/connect/shell/progress.py @@ -21,6 +21,7 @@ from dataclasses import dataclass import time import sys import typing +from types import TracebackType from typing import Iterable, Any from pyspark.sql.connect.proto import ExecutePlanResponse @@ -51,6 +52,7 @@ class ProgressHandler(abc.ABC): self, stages: typing.Optional[Iterable[StageInfo]], inflight_tasks: int, + operation_id: typing.Optional[str], done: bool, ) -> None: pass @@ -88,6 +90,7 @@ class Progress: output: typing.IO = sys.stdout, enabled: bool = False, handlers: Iterable[ProgressHandler] = [], + operation_id: typing.Optional[str] = None, ) -> None: """ Constructs a new Progress bar. The progress bar is typically used in @@ -107,8 +110,8 @@ class Progress: handlers : list of ProgressHandler A list of handlers that will be called when the progress bar is updated. 
""" - self._ticks = 0 - self._tick = 0 + self._ticks: typing.Optional[int] = None + self._tick: typing.Optional[int] = None x, y = get_terminal_size() self._min_width = min_width self._char = char @@ -121,16 +124,35 @@ class Progress: self._running = 0 self._handlers = handlers self._stages: Iterable[StageInfo] = [] + self._operation_id = operation_id def _notify(self, done: bool = False) -> None: for handler in self._handlers: handler( stages=self._stages, inflight_tasks=self._running, + operation_id=self._operation_id, done=done, ) - def update_ticks(self, stages: Iterable[StageInfo], inflight_tasks: int) -> None: + def __enter__(self) -> "Progress": + return self + + def __exit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]], + exception: typing.Optional[BaseException], + exc_tb: typing.Optional[TracebackType], + ) -> typing.Any: + self.finish() + return False + + def update_ticks( + self, + stages: Iterable[StageInfo], + inflight_tasks: int, + operation_id: typing.Optional[str] = None, + ) -> None: """This method is called from the execution to update the progress bar with a new total tick counter and the current position. This is necessary in case new stages get added with new tasks and so the total task number will be updated as well. @@ -142,13 +164,16 @@ class Progress: inflight_tasks : int The number of tasks that are currently running. 
""" + if self._operation_id is None or len(self._operation_id) == 0: + self._operation_id = operation_id + total_tasks = sum(map(lambda x: x.num_tasks, stages)) completed_tasks = sum(map(lambda x: x.num_completed_tasks, stages)) - if total_tasks > 0 and completed_tasks != self._tick: + if total_tasks > 0: self._ticks = total_tasks self._tick = completed_tasks self._bytes_read = sum(map(lambda x: x.num_bytes_read, stages)) - if self._tick > 0: + if self._tick is not None and self._tick >= 0: self.output() self._running = inflight_tasks self._stages = stages @@ -163,7 +188,7 @@ class Progress: def output(self) -> None: """Writes the progress bar out.""" - if self._enabled: + if self._enabled and self._tick is not None and self._ticks is not None: val = int((self._tick / float(self._ticks)) * self._width) bar = self._char * val + "-" * (self._width - val) percent_complete = (self._tick / self._ticks) * 100 diff --git a/python/pyspark/sql/tests/connect/shell/test_progress.py b/python/pyspark/sql/tests/connect/shell/test_progress.py index 7d99a699eefa..80c77b467070 100644 --- a/python/pyspark/sql/tests/connect/shell/test_progress.py +++ b/python/pyspark/sql/tests/connect/shell/test_progress.py @@ -19,6 +19,7 @@ from io import StringIO import unittest from typing import Iterable +from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase from pyspark.testing.connectutils import ( should_test_connect, connect_requirement_message, @@ -77,17 +78,22 @@ class ProgressBarTest(unittest.TestCase, PySparkErrorTestUtils): handler_called = 0 done_called = False - def handler(stages: Iterable[StageInfo], inflight_tasks: int, done: bool): + def handler( + stages: Iterable[StageInfo], inflight_tasks: int, operation_id: str, done: bool + ): nonlocal handler_called, done_called handler_called = 1 self.assertEqual(100, sum(map(lambda x: x.num_tasks, stages))) self.assertEqual(50, sum(map(lambda x: x.num_completed_tasks, stages))) self.assertEqual(999, sum(map(lambda 
x: x.num_bytes_read, stages))) self.assertEqual(10, inflight_tasks) + self.assertEqual(operation_id, "operation_id") done_called = done buffer = StringIO() - p = Progress(char="+", output=buffer, enabled=True, handlers=[handler]) + p = Progress( + char="+", output=buffer, enabled=True, handlers=[handler], operation_id="operation_id" + ) p.update_ticks(stages, 1) stages = [StageInfo(0, 100, 50, 999, False)] p.update_ticks(stages, 10) @@ -99,6 +105,27 @@ class ProgressBarTest(unittest.TestCase, PySparkErrorTestUtils): self.assertTrue(done_called, "After finish, done should be True") +class SparkConnectProgressHandlerE2E(SparkConnectSQLTestCase): + def test_custom_handler_works(self): + called = False + + def handler(**kwargs): + nonlocal called + called = True + self.assertIsNotNone(kwargs.get("stages")) + self.assertIsNotNone(kwargs.get("operation_id")) + self.assertIsNotNone(kwargs.get("inflight_tasks")) + self.assertGreater(len(kwargs.get("stages")), 0) + self.assertGreater(len(kwargs.get("operation_id")), 0) + + try: + self.connect.registerProgressHandler(handler) + self.connect.range(100).repartition(20).count() + self.assertTrue(called, "Handler must have been called") + finally: + self.connect.clearProgressHandlers() + + if __name__ == "__main__": from pyspark.sql.tests.connect.shell.test_progress import * # noqa: F401 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org