This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 07ec264da9ed [SPARK-45813][CONNECT][PYTHON] Return the observed metrics from commands
07ec264da9ed is described below

commit 07ec264da9ed56c0de21ab60fff95bab64d3579e
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Nov 15 10:49:22 2023 +0900

    [SPARK-45813][CONNECT][PYTHON] Return the observed metrics from commands

    ### What changes were proposed in this pull request?

    Returns the observed metrics from commands.

    ### Why are the changes needed?

    Currently the observed metrics on commands are not available. For example:

    ```py
    >>> df = spark.range(10)
    >>>
    >>> observation = Observation()
    >>> observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
    >>>
    >>> observed_df.show()
    ...
    >>> observation.get
    {}
    ```

    it should be:

    ```py
    >>> observation.get
    {'cnt': 10}
    ```

    ### Does this PR introduce _any_ user-facing change?

    Yes, the observed metrics on commands will be available.

    ### How was this patch tested?

    Added the related tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #43690 from ueshin/issues/SPARK-45813/observed_metrics.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../connect/execution/ExecuteThreadRunner.scala   | 15 +++++++++
 .../execution/SparkConnectPlanExecution.scala     | 37 ++++++++++++++++------
 .../sql/connect/planner/SparkConnectPlanner.scala | 13 ++++++--
 .../spark/sql/connect/service/ExecuteHolder.scala |  7 ++++
 python/pyspark/sql/connect/client/core.py         | 14 +++++---
 .../sql/tests/connect/test_connect_basic.py       |  2 +-
 python/pyspark/sql/tests/test_dataframe.py        | 21 ++++++++++++
 .../scala/org/apache/spark/sql/Observation.scala  | 31 +++++++++++++-----
 8 files changed, 115 insertions(+), 25 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index ea2bbe0093fc..24b3c302b759 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -162,6 +162,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
           s"${executeHolder.request.getPlan.getOpTypeCase} not supported.")
     }
 
+    if (executeHolder.observations.nonEmpty) {
+      val observedMetrics = executeHolder.observations.map { case (name, observation) =>
+        val values = observation.getOrEmpty.map { case (key, value) =>
+          (Some(key), value)
+        }.toSeq
+        name -> values
+      }.toMap
+      executeHolder.responseObserver.onNext(
+        SparkConnectPlanExecution
+          .createObservedMetricsResponse(
+            executeHolder.sessionHolder.sessionId,
+            executeHolder.sessionHolder.serverSessionId,
+            observedMetrics))
+    }
+
     lock.synchronized {
       // Synchronized before sending ResultComplete, and up until completing the result stream
       // to prevent a situation in which a client of reattachable execution receives
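For orientation, here is a minimal PySpark sketch of the behavior this change enables, mirroring the example in the commit message and the new test further below; it assumes a running `SparkSession` named `spark` and the built-in `noop` data source:

```py
from pyspark.sql import Observation
from pyspark.sql.functions import count, lit

# Attach an observation to a DataFrame, then execute a command (a write, not a collect).
observation = Observation()
observed_df = spark.range(10).observe(observation, count(lit(1)).alias("cnt"))

# Before this change, commands did not report observed metrics back to the client,
# so observation.get stayed {}. With this change it should contain {'cnt': 10}.
observed_df.write.format("noop").mode("overwrite").save()
print(observation.get)
```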
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 002239aba96e..23390bf7aba8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -66,9 +66,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
-    if (dataframe.queryExecution.observedMetrics.nonEmpty) {
-      responseObserver.onNext(createObservedMetricsResponse(request.getSessionId, dataframe))
-    }
+    createObservedMetricsResponse(request.getSessionId, dataframe).foreach(
+      responseObserver.onNext)
   }
 
   type Batch = (Array[Byte], Long)
@@ -245,15 +244,33 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
 
   private def createObservedMetricsResponse(
       sessionId: String,
-      dataframe: DataFrame): ExecutePlanResponse = {
-    val observedMetrics = dataframe.queryExecution.observedMetrics.map { case (name, row) =>
-      val cols = (0 until row.length).map(i => toLiteralProto(row(i)))
+      dataframe: DataFrame): Option[ExecutePlanResponse] = {
+    val observedMetrics = dataframe.queryExecution.observedMetrics.collect {
+      case (name, row) if !executeHolder.observations.contains(name) =>
+        val values = (0 until row.length).map { i =>
+          (if (row.schema != null) Some(row.schema.fieldNames(i)) else None, row(i))
+        }
+        name -> values
+    }
+    if (observedMetrics.nonEmpty) {
+      Some(SparkConnectPlanExecution
+        .createObservedMetricsResponse(sessionId, sessionHolder.serverSessionId, observedMetrics))
+    } else None
+  }
+}
+
+object SparkConnectPlanExecution {
+  def createObservedMetricsResponse(
+      sessionId: String,
+      serverSessionId: String,
+      metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = {
+    val observedMetrics = metrics.map { case (name, values) =>
       val metrics = ExecutePlanResponse.ObservedMetrics
         .newBuilder()
         .setName(name)
-        .addAllValues(cols.asJava)
-      if (row.schema != null) {
-        metrics.addAllKeys(row.schema.fieldNames.toList.asJava)
+      values.foreach { case (key, value) =>
+        metrics.addValues(toLiteralProto(value))
+        key.foreach(metrics.addKeys)
       }
       metrics.build()
     }
@@ -261,7 +278,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     ExecutePlanResponse
       .newBuilder()
       .setSessionId(sessionId)
-      .setServerSideSessionId(sessionHolder.serverSessionId)
+      .setServerSideSessionId(serverSessionId)
       .addAllObservedMetrics(observedMetrics.asJava)
       .build()
   }
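As a side note, the refactored helper carries metrics as `name -> Seq[(Option[String], Any)]`, so each value may or may not have a column name attached (the query path can lack a row schema). A rough Python illustration of how such pairs collapse into a keyed result; the names and values here are made up:

```py
from typing import Any, Dict, List, Optional, Sequence, Tuple

def pair_to_dict(values: Sequence[Tuple[Optional[str], Any]]) -> Dict[str, Any]:
    """Collapse (optional key, value) pairs into a dict, keeping only named entries."""
    return {key: value for key, value in values if key is not None}

# One observation named "my_metrics" with a named count and an unnamed value.
metrics: Dict[str, List[Tuple[Optional[str], Any]]] = {"my_metrics": [("cnt", 10), (None, 3.14)]}
print({name: pair_to_dict(values) for name, values in metrics.items()})
# {'my_metrics': {'cnt': 10}}
```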
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 654513857824..637ed09798a5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -42,7 +42,7 @@ import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.Streami
 import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{functions => MLFunctions}
-import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, RelationalGroupedDataset, SparkSession}
+import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -1069,8 +1069,17 @@ class SparkConnectPlanner(
     val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
       Column(transformExpression(expr))
     }
+    val name = rel.getName
+    val input = transformRelation(rel.getInput)
 
-    CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput), planId)
+    if (input.isStreaming || executeHolderOpt.isEmpty) {
+      CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
+    } else {
+      val observation = Observation(name)
+      observation.register(session, planId)
+      executeHolderOpt.get.addObservation(name, observation)
+      CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
+    }
   }
 
   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index eed8cc01f7c6..8b910154d2f4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._
 import org.apache.spark.{SparkEnv, SparkSQLException}
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Observation
 import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.config.Connect.CONNECT_EXECUTE_REATTACHABLE_ENABLED
 import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender, ExecuteResponseObserver, ExecuteThreadRunner}
@@ -89,6 +90,8 @@ private[connect] class ExecuteHolder(
 
   val eventsManager: ExecuteEventsManager = ExecuteEventsManager(this, new SystemClock())
 
+  val observations: mutable.Map[String, Observation] = mutable.Map.empty
+
   private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this)
 
   /** System.currentTimeMillis when this ExecuteHolder was created. */
@@ -132,6 +135,10 @@ private[connect] class ExecuteHolder(
     runner.join()
   }
 
+  def addObservation(name: String, observation: Observation): Unit = synchronized {
+    observations += (name -> observation)
+  }
+
   /**
    * Attach an ExecuteGrpcResponseSender that will consume responses from the query and send them
    * out on the Grpc response stream. The sender will start from the start of the response stream.
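Taken together, the server-side flow added here is: the planner registers an `Observation` per named metric on the `ExecuteHolder`, and after the command finishes the thread runner reads each one via `getOrEmpty`, which waits briefly for the listener to deliver metrics and otherwise returns an empty map. Below is a small, self-contained Python analogue of that handshake; it is toy code for illustration only, not Spark API:

```py
import threading

class ToyObservation:
    """Toy analogue of the register/getOrEmpty handshake used by the server."""

    def __init__(self):
        self._cond = threading.Condition()
        self._metrics = None

    def on_finish(self, metrics):
        # Called by the "listener" once the plan carrying the metrics completes.
        with self._cond:
            self._metrics = dict(metrics)
            self._cond.notify_all()

    def get_or_empty(self, timeout=0.1):
        # Wait briefly for metrics; fall back to an empty dict if none arrive.
        with self._cond:
            if self._metrics is None:
                self._cond.wait(timeout)
            return self._metrics or {}

# The execute holder keeps a name -> observation registry, filled in by the planner.
observations = {"my_metrics": ToyObservation()}
observations["my_metrics"].on_finish({"cnt": 10})
print({name: o.get_or_empty() for name, o in observations.items()})  # {'my_metrics': {'cnt': 10}}
```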
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index b98de0f9ceea..a2590dec960d 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1176,10 +1176,16 @@ class SparkConnectClient(object):
                 logger.debug("Received observed metric batch.")
                 for observed_metrics in self._build_observed_metrics(b.observed_metrics):
                     if observed_metrics.name in observations:
-                        observations[observed_metrics.name]._result = {
-                            key: LiteralExpression._to_value(metric)
-                            for key, metric in zip(observed_metrics.keys, observed_metrics.metrics)
-                        }
+                        observation_result = observations[observed_metrics.name]._result
+                        assert observation_result is not None
+                        observation_result.update(
+                            {
+                                key: LiteralExpression._to_value(metric)
+                                for key, metric in zip(
+                                    observed_metrics.keys, observed_metrics.metrics
+                                )
+                            }
+                        )
                     yield observed_metrics
             if b.HasField("schema"):
                 logger.debug("Received the schema.")
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index e926eb835a80..d2febcd6b089 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1824,7 +1824,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
 
         self.assert_eq(cdf, df)
 
-        self.assert_eq(cobservation.get, observation.get)
+        self.assertEquals(cobservation.get, observation.get)
 
         observed_metrics = cdf.attrs["observed_metrics"]
         self.assert_eq(len(observed_metrics), 1)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 527cf702bce9..3b2fb87123eb 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1078,6 +1078,27 @@ class DataFrameTestsMixin:
         self.assertEqual(observation1.get, dict(cnt=50))
         self.assertEqual(observation2.get, dict(cnt=100))
 
+    def test_observe_on_commands(self):
+        from pyspark.sql import Observation
+
+        df = self.spark.range(50)
+
+        test_table = "test_table"
+
+        # DataFrameWriter
+        with self.table(test_table):
+            for command, action in [
+                ("collect", lambda df: df.collect()),
+                ("show", lambda df: df.show(50)),
+                ("save", lambda df: df.write.format("noop").mode("overwrite").save()),
+                ("create", lambda df: df.writeTo(test_table).using("parquet").create()),
+            ]:
+                with self.subTest(command=command):
+                    observation = Observation()
+                    observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
+                    action(observed_df)
+                    self.assertEqual(observation.get, dict(cnt=50))
+
     def test_sample(self):
         with self.assertRaises(PySparkTypeError) as pe:
             self.spark.range(1).sample()
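On the client side, the core.py change above updates the observation's existing `_result` dict instead of replacing it, so metrics for the same observation that arrive in separate responses accumulate rather than overwrite each other. A minimal sketch of that merge pattern in plain Python (not the Connect client itself):

```py
def merge_observed_metrics(result, keys, values):
    """Merge one batch of observed metric key/value pairs into an existing result dict."""
    result.update(dict(zip(keys, values)))
    return result

# Metrics for one observation may arrive across more than one response batch;
# updating (rather than replacing) the dict keeps the earlier entries.
result = {}
merge_observed_metrics(result, ["cnt"], [10])
merge_observed_metrics(result, ["max_id"], [9])
print(result)  # {'cnt': 10, 'max_id': 9}
```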
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
index f4b518c1e9fb..104e7c101fd1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -58,7 +58,7 @@ class Observation(val name: String) {
 
   private val listener: ObservationListener = ObservationListener(this)
 
-  @volatile private var ds: Option[Dataset[_]] = None
+  @volatile private var dataframeId: Option[(SparkSession, Long)] = None
 
   @volatile private var metrics: Option[Map[String, Any]] = None
 
@@ -79,7 +79,7 @@ class Observation(val name: String) {
         ". Please register a StreamingQueryListener and get the metric for each microbatch in " +
         "QueryProgressEvent.progress, or use query.lastProgress or query.recentProgress.")
     }
-    register(ds)
+    register(ds.sparkSession, ds.id)
     ds.observe(name, expr, exprs: _*)
   }
 
@@ -117,29 +117,44 @@ class Observation(val name: String) {
     get.map { case (key, value) => (key, value.asInstanceOf[Object])}.asJava
   }
 
-  private def register(ds: Dataset[_]): Unit = {
+  /**
+   * Get the observed metrics. This returns the metrics if they are available, otherwise an empty.
+   *
+   * @return the observed metrics as a `Map[String, Any]`
+   */
+  @throws[InterruptedException]
+  private[sql] def getOrEmpty: Map[String, _] = {
+    synchronized {
+      if (metrics.isEmpty) {
+        wait(100) // Wait for 100ms to see if metrics are available
+      }
+      metrics.getOrElse(Map.empty)
+    }
+  }
+
+  private[sql] def register(sparkSession: SparkSession, dataframeId: Long): Unit = {
     // makes this class thread-safe:
     // only the first thread entering this block can set sparkSession
     // all other threads will see the exception, as it is only allowed to do this once
     synchronized {
-      if (this.ds.isDefined) {
+      if (this.dataframeId.isDefined) {
         throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
       }
-      this.ds = Some(ds)
+      this.dataframeId = Some((sparkSession, dataframeId))
     }
-    ds.sparkSession.listenerManager.register(this.listener)
+    sparkSession.listenerManager.register(this.listener)
   }
 
   private def unregister(): Unit = {
-    this.ds.foreach(_.sparkSession.listenerManager.unregister(this.listener))
+    this.dataframeId.foreach(_._1.listenerManager.unregister(this.listener))
  }
 
   private[spark] def onFinish(qe: QueryExecution): Unit = {
     synchronized {
       if (this.metrics.isEmpty && qe.logical.exists {
         case CollectMetrics(name, _, _, dataframeId) =>
-          name == this.name && dataframeId == ds.get.id
+          name == this.name && dataframeId == this.dataframeId.get._2
         case _ => false
       }) {
         val row = qe.observedMetrics.get(name)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org