This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 07ec264da9ed [SPARK-45813][CONNECT][PYTHON] Return the observed metrics from commands
07ec264da9ed is described below

commit 07ec264da9ed56c0de21ab60fff95bab64d3579e
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Nov 15 10:49:22 2023 +0900

    [SPARK-45813][CONNECT][PYTHON] Return the observed metrics from commands
    
    ### What changes were proposed in this pull request?
    
    Returns the observed metrics from commands.
    
    ### Why are the changes needed?
    
    Currently, the observed metrics from commands are not available.
    
    For example:
    
    ```py
    >>> df = spark.range(10)
    >>>
    >>> observation = Observation()
    >>> observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
    >>>
    >>> observed_df.show()
    ...
    >>> observation.get
    {}
    ```
    
    After this change, it should be:
    
    ```py
    >>> observation.get
    {'cnt': 10}
    ```
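
    The same gap applies to command-style executions such as writes. A minimal sketch of the intended behavior after this change, assuming an active `SparkSession` bound to `spark` (mirroring the test added to `python/pyspark/sql/tests/test_dataframe.py`):

    ```py
    >>> from pyspark.sql import Observation
    >>> from pyspark.sql.functions import count, lit
    >>>
    >>> df = spark.range(10)
    >>> observation = Observation()
    >>> observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
    >>>
    >>> # Writes are commands, so they should now also report observed metrics.
    >>> observed_df.write.format("noop").mode("overwrite").save()
    >>> observation.get
    {'cnt': 10}
    ```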
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the observed metrics from commands will now be available.
    
    ### How was this patch tested?
    
    Added the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43690 from ueshin/issues/SPARK-45813/observed_metrics.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../connect/execution/ExecuteThreadRunner.scala    | 15 +++++++++
 .../execution/SparkConnectPlanExecution.scala      | 37 ++++++++++++++++------
 .../sql/connect/planner/SparkConnectPlanner.scala  | 13 ++++++--
 .../spark/sql/connect/service/ExecuteHolder.scala  |  7 ++++
 python/pyspark/sql/connect/client/core.py          | 14 +++++---
 .../sql/tests/connect/test_connect_basic.py        |  2 +-
 python/pyspark/sql/tests/test_dataframe.py         | 21 ++++++++++++
 .../scala/org/apache/spark/sql/Observation.scala   | 31 +++++++++++++-----
 8 files changed, 115 insertions(+), 25 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index ea2bbe0093fc..24b3c302b759 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -162,6 +162,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
             s"${executeHolder.request.getPlan.getOpTypeCase} not supported.")
       }
 
+      if (executeHolder.observations.nonEmpty) {
+        val observedMetrics = executeHolder.observations.map { case (name, observation) =>
+          val values = observation.getOrEmpty.map { case (key, value) =>
+            (Some(key), value)
+          }.toSeq
+          name -> values
+        }.toMap
+        executeHolder.responseObserver.onNext(
+          SparkConnectPlanExecution
+            .createObservedMetricsResponse(
+              executeHolder.sessionHolder.sessionId,
+              executeHolder.sessionHolder.serverSessionId,
+              observedMetrics))
+      }
+
       lock.synchronized {
         // Synchronized before sending ResultComplete, and up until completing the result stream
         // to prevent a situation in which a client of reattachable execution receives
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 002239aba96e..23390bf7aba8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -66,9 +66,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
-    if (dataframe.queryExecution.observedMetrics.nonEmpty) {
-      responseObserver.onNext(createObservedMetricsResponse(request.getSessionId, dataframe))
-    }
+    createObservedMetricsResponse(request.getSessionId, dataframe).foreach(
+      responseObserver.onNext)
   }
 
   type Batch = (Array[Byte], Long)
@@ -245,15 +244,33 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
 
   private def createObservedMetricsResponse(
       sessionId: String,
-      dataframe: DataFrame): ExecutePlanResponse = {
-    val observedMetrics = dataframe.queryExecution.observedMetrics.map { case (name, row) =>
-      val cols = (0 until row.length).map(i => toLiteralProto(row(i)))
+      dataframe: DataFrame): Option[ExecutePlanResponse] = {
+    val observedMetrics = dataframe.queryExecution.observedMetrics.collect {
+      case (name, row) if !executeHolder.observations.contains(name) =>
+        val values = (0 until row.length).map { i =>
+          (if (row.schema != null) Some(row.schema.fieldNames(i)) else None, row(i))
+        }
+        name -> values
+    }
+    if (observedMetrics.nonEmpty) {
+      Some(SparkConnectPlanExecution
+        .createObservedMetricsResponse(sessionId, sessionHolder.serverSessionId, observedMetrics))
+    } else None
+  }
+}
+
+object SparkConnectPlanExecution {
+  def createObservedMetricsResponse(
+      sessionId: String,
+      serverSessionId: String,
+      metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = {
+    val observedMetrics = metrics.map { case (name, values) =>
       val metrics = ExecutePlanResponse.ObservedMetrics
         .newBuilder()
         .setName(name)
-        .addAllValues(cols.asJava)
-      if (row.schema != null) {
-        metrics.addAllKeys(row.schema.fieldNames.toList.asJava)
+      values.foreach { case (key, value) =>
+        metrics.addValues(toLiteralProto(value))
+        key.foreach(metrics.addKeys)
       }
       metrics.build()
     }
@@ -261,7 +278,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     ExecutePlanResponse
       .newBuilder()
       .setSessionId(sessionId)
-      .setServerSideSessionId(sessionHolder.serverSessionId)
+      .setServerSideSessionId(serverSessionId)
       .addAllObservedMetrics(observedMetrics.asJava)
       .build()
   }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 654513857824..637ed09798a5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -42,7 +42,7 @@ import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.Streami
 import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{functions => MLFunctions}
-import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, RelationalGroupedDataset, SparkSession}
+import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -1069,8 +1069,17 @@ class SparkConnectPlanner(
     val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
       Column(transformExpression(expr))
     }
+    val name = rel.getName
+    val input = transformRelation(rel.getInput)
 
-    CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput), planId)
+    if (input.isStreaming || executeHolderOpt.isEmpty) {
+      CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
+    } else {
+      val observation = Observation(name)
+      observation.register(session, planId)
+      executeHolderOpt.get.addObservation(name, observation)
+      CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
+    }
   }
 
   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index eed8cc01f7c6..8b910154d2f4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._
 import org.apache.spark.{SparkEnv, SparkSQLException}
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Observation
 import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.config.Connect.CONNECT_EXECUTE_REATTACHABLE_ENABLED
 import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender, ExecuteResponseObserver, ExecuteThreadRunner}
@@ -89,6 +90,8 @@ private[connect] class ExecuteHolder(
 
   val eventsManager: ExecuteEventsManager = ExecuteEventsManager(this, new SystemClock())
 
+  val observations: mutable.Map[String, Observation] = mutable.Map.empty
+
   private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this)
 
   /** System.currentTimeMillis when this ExecuteHolder was created. */
@@ -132,6 +135,10 @@ private[connect] class ExecuteHolder(
     runner.join()
   }
 
+  def addObservation(name: String, observation: Observation): Unit = synchronized {
+    observations += (name -> observation)
+  }
+
   /**
    * Attach an ExecuteGrpcResponseSender that will consume responses from the query and send them
    * out on the Grpc response stream. The sender will start from the start of the response stream.
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index b98de0f9ceea..a2590dec960d 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1176,10 +1176,16 @@ class SparkConnectClient(object):
                 logger.debug("Received observed metric batch.")
                 for observed_metrics in self._build_observed_metrics(b.observed_metrics):
                     if observed_metrics.name in observations:
-                        observations[observed_metrics.name]._result = {
-                            key: LiteralExpression._to_value(metric)
-                            for key, metric in zip(observed_metrics.keys, observed_metrics.metrics)
-                        }
+                        observation_result = observations[observed_metrics.name]._result
+                        assert observation_result is not None
+                        observation_result.update(
+                            {
+                                key: LiteralExpression._to_value(metric)
+                                for key, metric in zip(
+                                    observed_metrics.keys, observed_metrics.metrics
+                                )
+                            }
+                        )
                     yield observed_metrics
             if b.HasField("schema"):
                 logger.debug("Received the schema.")
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index e926eb835a80..d2febcd6b089 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1824,7 +1824,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
 
         self.assert_eq(cdf, df)
 
-        self.assert_eq(cobservation.get, observation.get)
+        self.assertEquals(cobservation.get, observation.get)
 
         observed_metrics = cdf.attrs["observed_metrics"]
         self.assert_eq(len(observed_metrics), 1)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 527cf702bce9..3b2fb87123eb 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1078,6 +1078,27 @@ class DataFrameTestsMixin:
         self.assertEqual(observation1.get, dict(cnt=50))
         self.assertEqual(observation2.get, dict(cnt=100))
 
+    def test_observe_on_commands(self):
+        from pyspark.sql import Observation
+
+        df = self.spark.range(50)
+
+        test_table = "test_table"
+
+        # DataFrameWriter
+        with self.table(test_table):
+            for command, action in [
+                ("collect", lambda df: df.collect()),
+                ("show", lambda df: df.show(50)),
+                ("save", lambda df: 
df.write.format("noop").mode("overwrite").save()),
+                ("create", lambda df: 
df.writeTo(test_table).using("parquet").create()),
+            ]:
+                with self.subTest(command=command):
+                    observation = Observation()
+                    observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
+                    action(observed_df)
+                    self.assertEqual(observation.get, dict(cnt=50))
+
     def test_sample(self):
         with self.assertRaises(PySparkTypeError) as pe:
             self.spark.range(1).sample()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
index f4b518c1e9fb..104e7c101fd1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -58,7 +58,7 @@ class Observation(val name: String) {
 
   private val listener: ObservationListener = ObservationListener(this)
 
-  @volatile private var ds: Option[Dataset[_]] = None
+  @volatile private var dataframeId: Option[(SparkSession, Long)] = None
 
   @volatile private var metrics: Option[Map[String, Any]] = None
 
@@ -79,7 +79,7 @@ class Observation(val name: String) {
         ". Please register a StreamingQueryListener and get the metric for 
each microbatch in " +
         "QueryProgressEvent.progress, or use query.lastProgress or 
query.recentProgress.")
     }
-    register(ds)
+    register(ds.sparkSession, ds.id)
     ds.observe(name, expr, exprs: _*)
   }
 
@@ -117,29 +117,44 @@ class Observation(val name: String) {
       get.map { case (key, value) => (key, value.asInstanceOf[Object])}.asJava
   }
 
-  private def register(ds: Dataset[_]): Unit = {
+  /**
+   * Get the observed metrics. This returns the metrics if they are available, otherwise an empty.
+   *
+   * @return the observed metrics as a `Map[String, Any]`
+   */
+  @throws[InterruptedException]
+  private[sql] def getOrEmpty: Map[String, _] = {
+    synchronized {
+      if (metrics.isEmpty) {
+        wait(100) // Wait for 100ms to see if metrics are available
+      }
+      metrics.getOrElse(Map.empty)
+    }
+  }
+
+  private[sql] def register(sparkSession: SparkSession, dataframeId: Long): Unit = {
     // makes this class thread-safe:
     // only the first thread entering this block can set sparkSession
     // all other threads will see the exception, as it is only allowed to do this once
     synchronized {
-      if (this.ds.isDefined) {
+      if (this.dataframeId.isDefined) {
         throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
       }
-      this.ds = Some(ds)
+      this.dataframeId = Some((sparkSession, dataframeId))
     }
 
-    ds.sparkSession.listenerManager.register(this.listener)
+    sparkSession.listenerManager.register(this.listener)
   }
 
   private def unregister(): Unit = {
-    this.ds.foreach(_.sparkSession.listenerManager.unregister(this.listener))
+    this.dataframeId.foreach(_._1.listenerManager.unregister(this.listener))
   }
 
   private[spark] def onFinish(qe: QueryExecution): Unit = {
     synchronized {
       if (this.metrics.isEmpty && qe.logical.exists {
         case CollectMetrics(name, _, _, dataframeId) =>
-          name == this.name && dataframeId == ds.get.id
+          name == this.name && dataframeId == this.dataframeId.get._2
         case _ => false
       }) {
         val row = qe.observedMetrics.get(name)

