This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 21548a8cc5c5 [SPARK-47545][CONNECT] Dataset `observe` support for the Scala client
21548a8cc5c5 is described below

commit 21548a8cc5c527d4416a276a852f967b4410bd4b
Author: Paddy Xu <xupa...@gmail.com>
AuthorDate: Wed May 8 15:44:02 2024 -0400

    [SPARK-47545][CONNECT] Dataset `observe` support for the Scala client
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for `Dataset.observe` to the Spark Connect Scala client. Note that listener support is not included, as listeners run on the server side.
    
    This PR includes a small refactoring of the `Observation` helper class: methods that are not bound to the SparkSession were extracted into `spark-api`, and a subclass was added in each of `spark-core` and `spark-jvm-client`, as sketched below.
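
    As a rough sketch of the resulting split (simplified; member bodies are abbreviated from the diff below and are not the exact code):

    ```scala
    package org.apache.spark.sql

    // sql/api ("spark-api"): session-agnostic base that owns the metrics and the wait/notify logic.
    abstract class ObservationBase(val name: String) {
      @volatile protected var metrics: Option[Map[String, Any]] = None

      // Blocks until the first action on the observed Dataset publishes metrics.
      def get: Map[String, _] = synchronized {
        while (metrics.isEmpty) wait()
        metrics.get
      }

      // Called once metrics arrive; wakes up any thread blocked in `get`.
      private[spark] def setMetricsAndNotify(m: Option[Map[String, Any]]): Boolean = synchronized {
        this.metrics = m
        if (m.isDefined) { notifyAll(); true } else false
      }
    }

    // spark-core and the Connect JVM client ("spark-jvm-client") each add a thin subclass;
    // on the client side it is essentially just a named (or randomly named) ObservationBase.
    class Observation(name: String) extends ObservationBase(name) {
      def this() = this(java.util.UUID.randomUUID().toString)
    }
    ```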
    
    ### Why are the changes needed?
    
    Before this PR, the `DF.observe` method was only supported in the Python client.
    
    ### Does this PR introduce _any_ user-facing change?

    Yes. The user can now call `DF.observe(name, metrics...)` or `DF.observe(observationObject, metrics...)` to compute statistics over the columns of a DataFrame.
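
    A minimal usage sketch (the `spark` handle is assumed to be a Spark Connect session; data and metric names are placeholders):

    ```scala
    import org.apache.spark.sql.Observation
    import org.apache.spark.sql.functions._
    import spark.implicits._

    val observation = Observation("my_metrics")
    val observedDf = spark
      .range(100)
      .observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))

    // Any action triggers metric collection; the e2e test in this PR uses collect().
    observedDf.collect()
    val metrics = observation.get // Map("rows" -> 100, "maxid" -> 99)
    ```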
    
    ### How was this patch tested?
    
    Added new e2e tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Nope.
    
    Closes #45701 from xupefei/scala-observe.
    
    Authored-by: Paddy Xu <xupa...@gmail.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  63 ++++++-
 .../scala/org/apache/spark/sql/Observation.scala   |  46 +++++
 .../scala/org/apache/spark/sql/SparkSession.scala  |  31 +++-
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  43 +++++
 .../CheckConnectJvmClientCompatibility.scala       |   3 -
 .../src/main/protobuf/spark/connect/base.proto     |   1 +
 .../spark/sql/connect/client/SparkResult.scala     |  44 ++++-
 .../common/LiteralValueProtoConverter.scala        |   2 +-
 .../connect/execution/ExecuteThreadRunner.scala    |   1 +
 .../execution/SparkConnectPlanExecution.scala      |  12 +-
 python/pyspark/sql/connect/proto/base_pb2.py       | 188 ++++++++++-----------
 python/pyspark/sql/connect/proto/base_pb2.pyi      |   5 +-
 .../org/apache/spark/sql/ObservationBase.scala     | 113 +++++++++++++
 .../scala/org/apache/spark/sql/Observation.scala   |  62 +------
 14 files changed, 448 insertions(+), 166 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9a42afebf8f2..37f770319b69 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3337,8 +3337,69 @@ class Dataset[T] private[sql] (
     }
   }
 
+  /**
+   * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset
+   * that returns the same result as the input, with the following guarantees: <ul> <li>It will
+   * compute the defined aggregates (metrics) on all the data that is flowing through the Dataset
+   * at that point.</li> <li>It will report the value of the defined aggregate columns as soon as
+   * we reach a completion point. A completion point is currently defined as the end of a
+   * query.</li> </ul> Please note that continuous execution is currently not supported.
+   *
+   * The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or
+   * more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that
+   * contain references to the input Dataset's columns must always be wrapped in an aggregate
+   * function.
+   *
+   * A user can retrieve the metrics by calling
+   * `org.apache.spark.sql.Dataset.collectResult().getObservedMetrics`.
+   *
+   * {{{
+   *   // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
+   *   val observed_ds = ds.observe("my_metrics", count(lit(1)).as("rows"), max($"id").as("maxid"))
+   *   observed_ds.write.parquet("ds.parquet")
+   *   val metrics = observed_ds.collectResult().getObservedMetrics
+   * }}}
+   *
+   * @group typedrel
+   * @since 4.0.0
+   */
+  @scala.annotation.varargs
   def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = {
-    throw new UnsupportedOperationException("observe is not implemented.")
+    sparkSession.newDataset(agnosticEncoder) { builder =>
+      builder.getCollectMetricsBuilder
+        .setInput(plan.getRoot)
+        .setName(name)
+        .addAllMetrics((expr +: exprs).map(_.expr).asJava)
+    }
+  }
+
+  /**
+   * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This is
+   * equivalent to calling `observe(String, Column, Column*)` but does not require to collect all
+   * results before returning the metrics - the metrics are filled during iterating the results,
+   * as soon as they are available. This method does not support streaming datasets.
+   *
+   * A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`.
+   *
+   * {{{
+   *   // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
+   *   val observation = Observation("my_metrics")
+   *   val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
+   *   observed_ds.write.parquet("ds.parquet")
+   *   val metrics = observation.get
+   * }}}
+   *
+   * @throws IllegalArgumentException
+   *   If this is a streaming Dataset (this.isStreaming == true)
+   *
+   * @group typedrel
+   * @since 4.0.0
+   */
+  @scala.annotation.varargs
+  def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = {
+    val df = observe(observation.name, expr, exprs: _*)
+    sparkSession.registerObservation(df.getPlanId.get, observation)
+    df
   }
 
   def checkpoint(): Dataset[T] = {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala
new file mode 100644
index 000000000000..75629b6000f9
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.UUID
+
+class Observation(name: String) extends ObservationBase(name) {
+
+  /**
+   * Create an Observation instance without providing a name. This generates a random name.
+   */
+  def this() = this(UUID.randomUUID().toString)
+}
+
+/**
+ * (Scala-specific) Create instances of Observation via Scala `apply`.
+ * @since 4.0.0
+ */
+object Observation {
+
+  /**
+   * Observation constructor for creating an anonymous observation.
+   */
+  def apply(): Observation = new Observation()
+
+  /**
+   * Observation constructor for creating a named observation.
+   */
+  def apply(name: String): Observation = new Observation(name)
+
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 22bb62803fac..1188fba60a2f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql
 
 import java.io.Closeable
 import java.net.URI
+import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.TimeUnit._
 import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
 
@@ -36,7 +37,7 @@ import org.apache.spark.sql.catalog.Catalog
 import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
-import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, SparkResult}
+import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.functions.lit
@@ -80,6 +81,8 @@ class SparkSession private[sql] (
     client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION).getSparkVersion.getVersion
   }
 
+  private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]()
+
   /**
    * Runtime configuration interface for Spark.
    *
@@ -532,8 +535,12 @@ class SparkSession private[sql] (
 
   private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
     val value = client.execute(plan)
-    val result = new SparkResult(value, allocator, encoder, timeZoneId)
-    result
+    new SparkResult(
+      value,
+      allocator,
+      encoder,
+      timeZoneId,
+      Some(setMetricsAndUnregisterObservation))
   }
 
   private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = {
@@ -554,6 +561,9 @@ class SparkSession private[sql] (
     client.execute(plan).filter(!_.hasExecutionProgress).toSeq
   }
 
+  private[sql] def execute(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] =
+    client.execute(plan)
+
   private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
     val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
     execute(command)
@@ -779,6 +789,21 @@ class SparkSession private[sql] (
    * Set to false to prevent client.releaseSession on close() (testing only)
    */
   private[sql] var releaseSessionOnClose = true
+
+  private[sql] def registerObservation(planId: Long, observation: Observation): Unit = {
+    if (observationRegistry.putIfAbsent(planId, observation) != null) {
+      throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
+    }
+  }
+
+  private[sql] def setMetricsAndUnregisterObservation(
+      planId: Long,
+      metrics: Map[String, Any]): Unit = {
+    val observationOrNull = observationRegistry.remove(planId)
+    if (observationOrNull != null) {
+      observationOrNull.setMetricsAndNotify(Some(metrics))
+    }
+  }
 }
 
 // The minimal builder needed to create a spark session.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index a0729adb8960..73a2f6d4f88e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -22,6 +22,8 @@ import java.time.DateTimeException
 import java.util.Properties
 
 import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.DurationInt
 import scala.jdk.CollectionConverters._
 
 import org.apache.commons.io.FileUtils
@@ -41,6 +43,7 @@ import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession, SQLHelper}
 import org.apache.spark.sql.test.SparkConnectServerUtils.port
 import org.apache.spark.sql.types._
+import org.apache.spark.util.SparkThreadUtils
 
 class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
 
@@ -1511,6 +1514,46 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
         (0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1)))
       }
   }
+
+  test("Observable metrics") {
+    val df = spark.range(99).withColumn("extra", col("id") - 1)
+    val ob1 = new Observation("ob1")
+    val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
+    val observedObservedDf = observedDf.observe("ob2", min("extra"), avg("extra"), max("extra"))
+
+    val ob1Schema = new StructType()
+      .add("min(id)", LongType)
+      .add("avg(id)", DoubleType)
+      .add("max(id)", LongType)
+    val ob2Schema = new StructType()
+      .add("min(extra)", LongType)
+      .add("avg(extra)", DoubleType)
+      .add("max(extra)", LongType)
+    val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema))
+    val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema))
+
+    assert(df.collectResult().getObservedMetrics === Map.empty)
+    assert(observedDf.collectResult().getObservedMetrics === ob1Metrics)
+    assert(observedObservedDf.collectResult().getObservedMetrics === ob1Metrics ++ ob2Metrics)
+  }
+
+  test("Observation.get is blocked until the query is finished") {
+    val df = spark.range(99).withColumn("extra", col("id") - 1)
+    val observation = new Observation("ob1")
+    val observedDf = df.observe(observation, min("id"), avg("id"), max("id"))
+
+    // Start a new thread to get the observation
+    val future = Future(observation.get)(ExecutionContext.global)
+    // make sure the thread is blocked right now
+    val e = intercept[java.util.concurrent.TimeoutException] {
+      SparkThreadUtils.awaitResult(future, 2.seconds)
+    }
+    assert(e.getMessage.contains("Future timed out"))
+    observedDf.collect()
+    // make sure the thread is unblocked after the query is finished
+    val metrics = SparkThreadUtils.awaitResult(future, 2.seconds)
+    assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index c89dba03ed69..7be5e2ecd172 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -196,9 +196,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.COL_POS_KEY"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_KEY"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.curId"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.observe"),
-      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"),
-      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 49a33d3419b6..77dda277602a 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -434,6 +434,7 @@ message ExecutePlanResponse {
     string name = 1;
     repeated Expression.Literal values = 2;
     repeated string keys = 3;
+    int64 plan_id = 4;
   }
 
   message ResultComplete {
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 93d1075aea02..0905ee76c3f3 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -27,10 +27,13 @@ import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
 import org.apache.arrow.vector.types.pojo
 
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
-import org.apache.spark.sql.connect.common.DataTypeProtoConverter
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.ArrowUtils
 
@@ -38,7 +41,8 @@ private[sql] class SparkResult[T](
     responses: CloseableIterator[proto.ExecutePlanResponse],
     allocator: BufferAllocator,
     encoder: AgnosticEncoder[T],
-    timeZoneId: String)
+    timeZoneId: String,
+    setObservationMetricsOpt: Option[(Long, Map[String, Any]) => Unit] = None)
     extends AutoCloseable { self =>
 
   case class StageInfo(
@@ -79,6 +83,7 @@ private[sql] class SparkResult[T](
   private[this] var arrowSchema: pojo.Schema = _
   private[this] var nextResultIndex: Int = 0
   private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])]
+  private val observedMetrics = mutable.Map.empty[String, Row]
   private val cleanable =
     SparkResult.cleaner.register(this, new SparkResultCloseable(resultMap, responses))
 
@@ -117,6 +122,9 @@ private[sql] class SparkResult[T](
     while (!stop && responses.hasNext) {
       val response = responses.next()
 
+      // Collect metrics for this response
+      observedMetrics ++= processObservedMetrics(response.getObservedMetricsList)
+
       // Save and validate operationId
       if (opId == null) {
         opId = response.getOperationId
@@ -198,6 +206,29 @@ private[sql] class SparkResult[T](
     nonEmpty
   }
 
+  private def processObservedMetrics(
+      metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = {
+    metrics.asScala.map { metric =>
+      assert(metric.getKeysCount == metric.getValuesCount)
+      var schema = new StructType()
+      val keys = mutable.ListBuffer.empty[String]
+      val values = mutable.ListBuffer.empty[Any]
+      (0 until metric.getKeysCount).map { i =>
+        val key = metric.getKeys(i)
+        val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
+        schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
+        keys += key
+        values += value
+      }
+      // If the metrics is registered by an Observation object, attach them and unblock any
+      // blocked thread.
+      setObservationMetricsOpt.foreach { setObservationMetrics =>
+        setObservationMetrics(metric.getPlanId, keys.zip(values).toMap)
+      }
+      metric.getName -> new GenericRowWithSchema(values.toArray, schema)
+    }
+  }
+
   /**
    * Returns the number of elements in the result.
    */
@@ -248,6 +279,15 @@ private[sql] class SparkResult[T](
     result
   }
 
+  /**
+   * Returns all observed metrics in the result.
+   */
+  def getObservedMetrics: Map[String, Row] = {
+    // We need to process all responses to get all metrics.
+    processResponses()
+    observedMetrics.toMap
+  }
+
   /**
    * Returns an iterator over the contents of the result.
    */
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index ce42cc797bf3..1f3496fa8984 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -204,7 +204,7 @@ object LiteralValueProtoConverter {
   def toLiteralProto(literal: Any, dataType: DataType): proto.Expression.Literal =
     toLiteralProtoBuilder(literal, dataType).build()
 
-  private def toDataType(clz: Class[_]): DataType = clz match {
+  private[sql] def toDataType(clz: Class[_]): DataType = clz match {
     // primitive types
     case JShort.TYPE => ShortType
     case JInteger.TYPE => IntegerType
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 0a6d12cbb191..4ef4f632204b 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -220,6 +220,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
             .createObservedMetricsResponse(
               executeHolder.sessionHolder.sessionId,
               executeHolder.sessionHolder.serverSessionId,
+              executeHolder.request.getPlan.getRoot.getCommon.getPlanId,
               observedMetrics ++ accumulatedInPython))
       }
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 4f2b8c945127..660951f22984 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -264,8 +264,14 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
         name -> values
     }
     if (observedMetrics.nonEmpty) {
-      Some(SparkConnectPlanExecution
-        .createObservedMetricsResponse(sessionId, sessionHolder.serverSessionId, observedMetrics))
+      val planId = executeHolder.request.getPlan.getRoot.getCommon.getPlanId
+      Some(
+        SparkConnectPlanExecution
+          .createObservedMetricsResponse(
+            sessionId,
+            sessionHolder.serverSessionId,
+            planId,
+            observedMetrics))
     } else None
   }
 }
@@ -274,11 +280,13 @@ object SparkConnectPlanExecution {
   def createObservedMetricsResponse(
       sessionId: String,
       serverSessionId: String,
+      planId: Long,
       metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = {
     val observedMetrics = metrics.map { case (name, values) =>
       val metrics = ExecutePlanResponse.ObservedMetrics
         .newBuilder()
         .setName(name)
+        .setPlanId(planId)
       values.foreach { case (key, value) =>
         metrics.addValues(toLiteralProto(value))
         key.foreach(metrics.addKeys)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 2a30ffe60a9f..a39396db4ff1 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
 
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
 [...]
+    
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
 
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -120,7 +120,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 5196
     _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5361
     _EXECUTEPLANRESPONSE._serialized_start = 5440
-    _EXECUTEPLANRESPONSE._serialized_end = 8230
+    _EXECUTEPLANRESPONSE._serialized_end = 8256
     _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 7030
     _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 7101
     _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 7103
@@ -133,96 +133,96 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 7651
     _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 7653
     _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 7741
-    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7743
-    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7859
-    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7861
-    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7877
-    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_start = 7880
-    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_end = 8213
-    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_start = 8036
-    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_end = 8213
-    _KEYVALUE._serialized_start = 8232
-    _KEYVALUE._serialized_end = 8297
-    _CONFIGREQUEST._serialized_start = 8300
-    _CONFIGREQUEST._serialized_end = 9459
-    _CONFIGREQUEST_OPERATION._serialized_start = 8608
-    _CONFIGREQUEST_OPERATION._serialized_end = 9106
-    _CONFIGREQUEST_SET._serialized_start = 9108
-    _CONFIGREQUEST_SET._serialized_end = 9160
-    _CONFIGREQUEST_GET._serialized_start = 9162
-    _CONFIGREQUEST_GET._serialized_end = 9187
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 9189
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 9252
-    _CONFIGREQUEST_GETOPTION._serialized_start = 9254
-    _CONFIGREQUEST_GETOPTION._serialized_end = 9285
-    _CONFIGREQUEST_GETALL._serialized_start = 9287
-    _CONFIGREQUEST_GETALL._serialized_end = 9335
-    _CONFIGREQUEST_UNSET._serialized_start = 9337
-    _CONFIGREQUEST_UNSET._serialized_end = 9364
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 9366
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 9400
-    _CONFIGRESPONSE._serialized_start = 9462
-    _CONFIGRESPONSE._serialized_end = 9637
-    _ADDARTIFACTSREQUEST._serialized_start = 9640
-    _ADDARTIFACTSREQUEST._serialized_end = 10642
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 10115
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 10168
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 10170
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 10281
-    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 10283
-    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 10376
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 10379
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 10572
-    _ADDARTIFACTSRESPONSE._serialized_start = 10645
-    _ADDARTIFACTSRESPONSE._serialized_end = 10917
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 10836
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 10917
-    _ARTIFACTSTATUSESREQUEST._serialized_start = 10920
-    _ARTIFACTSTATUSESREQUEST._serialized_end = 11246
-    _ARTIFACTSTATUSESRESPONSE._serialized_start = 11249
-    _ARTIFACTSTATUSESRESPONSE._serialized_end = 11601
-    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 11444
-    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 11559
-    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 11561
-    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 11601
-    _INTERRUPTREQUEST._serialized_start = 11604
-    _INTERRUPTREQUEST._serialized_end = 12207
-    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 12007
-    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 12135
-    _INTERRUPTRESPONSE._serialized_start = 12210
-    _INTERRUPTRESPONSE._serialized_end = 12354
-    _REATTACHOPTIONS._serialized_start = 12356
-    _REATTACHOPTIONS._serialized_end = 12409
-    _REATTACHEXECUTEREQUEST._serialized_start = 12412
-    _REATTACHEXECUTEREQUEST._serialized_end = 12818
-    _RELEASEEXECUTEREQUEST._serialized_start = 12821
-    _RELEASEEXECUTEREQUEST._serialized_end = 13406
-    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 13275
-    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 13287
-    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 13289
-    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 13336
-    _RELEASEEXECUTERESPONSE._serialized_start = 13409
-    _RELEASEEXECUTERESPONSE._serialized_end = 13574
-    _RELEASESESSIONREQUEST._serialized_start = 13577
-    _RELEASESESSIONREQUEST._serialized_end = 13748
-    _RELEASESESSIONRESPONSE._serialized_start = 13750
-    _RELEASESESSIONRESPONSE._serialized_end = 13858
-    _FETCHERRORDETAILSREQUEST._serialized_start = 13861
-    _FETCHERRORDETAILSREQUEST._serialized_end = 14193
-    _FETCHERRORDETAILSRESPONSE._serialized_start = 14196
-    _FETCHERRORDETAILSRESPONSE._serialized_end = 15751
-    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 14425
-    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 14599
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 14602
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 14970
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 14933
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 14970
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 14973
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 15382
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 15284
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 15352
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 15385
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 15732
-    _SPARKCONNECTSERVICE._serialized_start = 15754
-    _SPARKCONNECTSERVICE._serialized_end = 16700
+    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7744
+    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7885
+    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7887
+    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7903
+    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_start = 7906
+    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_end = 8239
+    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_start = 8062
+    _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_end = 8239
+    _KEYVALUE._serialized_start = 8258
+    _KEYVALUE._serialized_end = 8323
+    _CONFIGREQUEST._serialized_start = 8326
+    _CONFIGREQUEST._serialized_end = 9485
+    _CONFIGREQUEST_OPERATION._serialized_start = 8634
+    _CONFIGREQUEST_OPERATION._serialized_end = 9132
+    _CONFIGREQUEST_SET._serialized_start = 9134
+    _CONFIGREQUEST_SET._serialized_end = 9186
+    _CONFIGREQUEST_GET._serialized_start = 9188
+    _CONFIGREQUEST_GET._serialized_end = 9213
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 9215
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 9278
+    _CONFIGREQUEST_GETOPTION._serialized_start = 9280
+    _CONFIGREQUEST_GETOPTION._serialized_end = 9311
+    _CONFIGREQUEST_GETALL._serialized_start = 9313
+    _CONFIGREQUEST_GETALL._serialized_end = 9361
+    _CONFIGREQUEST_UNSET._serialized_start = 9363
+    _CONFIGREQUEST_UNSET._serialized_end = 9390
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 9392
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 9426
+    _CONFIGRESPONSE._serialized_start = 9488
+    _CONFIGRESPONSE._serialized_end = 9663
+    _ADDARTIFACTSREQUEST._serialized_start = 9666
+    _ADDARTIFACTSREQUEST._serialized_end = 10668
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 10141
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 10194
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 10196
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 10307
+    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 10309
+    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 10402
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 10405
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 10598
+    _ADDARTIFACTSRESPONSE._serialized_start = 10671
+    _ADDARTIFACTSRESPONSE._serialized_end = 10943
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 10862
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 10943
+    _ARTIFACTSTATUSESREQUEST._serialized_start = 10946
+    _ARTIFACTSTATUSESREQUEST._serialized_end = 11272
+    _ARTIFACTSTATUSESRESPONSE._serialized_start = 11275
+    _ARTIFACTSTATUSESRESPONSE._serialized_end = 11627
+    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 11470
+    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 11585
+    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 11587
+    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 11627
+    _INTERRUPTREQUEST._serialized_start = 11630
+    _INTERRUPTREQUEST._serialized_end = 12233
+    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 12033
+    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 12161
+    _INTERRUPTRESPONSE._serialized_start = 12236
+    _INTERRUPTRESPONSE._serialized_end = 12380
+    _REATTACHOPTIONS._serialized_start = 12382
+    _REATTACHOPTIONS._serialized_end = 12435
+    _REATTACHEXECUTEREQUEST._serialized_start = 12438
+    _REATTACHEXECUTEREQUEST._serialized_end = 12844
+    _RELEASEEXECUTEREQUEST._serialized_start = 12847
+    _RELEASEEXECUTEREQUEST._serialized_end = 13432
+    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 13301
+    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 13313
+    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 13315
+    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 13362
+    _RELEASEEXECUTERESPONSE._serialized_start = 13435
+    _RELEASEEXECUTERESPONSE._serialized_end = 13600
+    _RELEASESESSIONREQUEST._serialized_start = 13603
+    _RELEASESESSIONREQUEST._serialized_end = 13774
+    _RELEASESESSIONRESPONSE._serialized_start = 13776
+    _RELEASESESSIONRESPONSE._serialized_end = 13884
+    _FETCHERRORDETAILSREQUEST._serialized_start = 13887
+    _FETCHERRORDETAILSREQUEST._serialized_end = 14219
+    _FETCHERRORDETAILSRESPONSE._serialized_start = 14222
+    _FETCHERRORDETAILSRESPONSE._serialized_end = 15777
+    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 14451
+    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 14625
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 14628
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 14996
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 14959
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 14996
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 14999
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 15408
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 15310
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 15378
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 15411
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 15758
+    _SPARKCONNECTSERVICE._serialized_start = 15780
+    _SPARKCONNECTSERVICE._serialized_end = 16726
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index d22502f8839d..b76f2a7f4de3 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1406,6 +1406,7 @@ class ExecutePlanResponse(google.protobuf.message.Message):
         NAME_FIELD_NUMBER: builtins.int
         VALUES_FIELD_NUMBER: builtins.int
         KEYS_FIELD_NUMBER: builtins.int
+        PLAN_ID_FIELD_NUMBER: builtins.int
         name: builtins.str
         @property
         def values(
@@ -1417,6 +1418,7 @@ class ExecutePlanResponse(google.protobuf.message.Message):
         def keys(
             self,
         ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
+        plan_id: builtins.int
         def __init__(
             self,
             *,
@@ -1426,11 +1428,12 @@ class ExecutePlanResponse(google.protobuf.message.Message):
             ]
             | None = ...,
             keys: collections.abc.Iterable[builtins.str] | None = ...,
+            plan_id: builtins.int = ...,
         ) -> None: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "keys", b"keys", "name", b"name", "values", b"values"
+                "keys", b"keys", "name", b"name", "plan_id", b"plan_id", "values", b"values"
             ],
         ) -> None: ...
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala
new file mode 100644
index 000000000000..4789ae8975d1
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.jdk.CollectionConverters.MapHasAsJava
+
+/**
+ * Helper class to simplify usage of `Dataset.observe(String, Column, Column*)`:
+ *
+ * {{{
+ *   // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
+ *   val observation = Observation("my metrics")
+ *   val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
+ *   observed_ds.write.parquet("ds.parquet")
+ *   val metrics = observation.get
+ * }}}
+ *
+ * This collects the metrics while the first action is executed on the observed dataset. Subsequent
+ * actions do not modify the metrics returned by [[get]]. Retrieval of the metric via [[get]]
+ * blocks until the first action has finished and metrics become available.
+ *
+ * This class does not support streaming datasets.
+ *
+ * @param name name of the metric
+ * @since 3.3.0
+ */
+abstract class ObservationBase(val name: String) {
+
+  if (name.isEmpty) throw new IllegalArgumentException("Name must not be empty")
+
+  @volatile protected var metrics: Option[Map[String, Any]] = None
+
+  /**
+   * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish
+   * its first action. Only the result of the first action is available. Subsequent actions do not
+   * modify the result.
+   *
+   * @return the observed metrics as a `Map[String, Any]`
+   * @throws InterruptedException interrupted while waiting
+   */
+  @throws[InterruptedException]
+  def get: Map[String, _] = {
+    synchronized {
+      // we need to loop as wait might return without us calling notify
+      // https://en.wikipedia.org/w/index.php?title=Spurious_wakeup&oldid=992601610
+      while (this.metrics.isEmpty) {
+        wait()
+      }
+    }
+
+    this.metrics.get
+  }
+
+  /**
+   * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish
+   * its first action. Only the result of the first action is available. Subsequent actions do not
+   * modify the result.
+   *
+   * @return the observed metrics as a `java.util.Map[String, Object]`
+   * @throws InterruptedException interrupted while waiting
+   */
+  @throws[InterruptedException]
+  def getAsJava: java.util.Map[String, AnyRef] = {
+    get.map { case (key, value) => (key, value.asInstanceOf[Object]) }.asJava
+  }
+
+  /**
+   * Get the observed metrics. This returns the metrics if they are available, otherwise an empty.
+   *
+   * @return the observed metrics as a `Map[String, Any]`
+   */
+  @throws[InterruptedException]
+  private[sql] def getOrEmpty: Map[String, _] = {
+    synchronized {
+      if (metrics.isEmpty) {
+        wait(100) // Wait for 100ms to see if metrics are available
+      }
+      metrics.getOrElse(Map.empty)
+    }
+  }
+
+  /**
+   * Set the observed metrics and notify all waiting threads to resume.
+   *
+   * @return `true` if all waiting threads were notified, `false` if otherwise.
+   */
+  private[spark] def setMetricsAndNotify(metrics: Option[Map[String, Any]]): Boolean = {
+    synchronized {
+      this.metrics = metrics
+      if(metrics.isDefined) {
+        notifyAll()
+        true
+      } else {
+        false
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
index 104e7c101fd1..30d5943c6092 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql
 
 import java.util.UUID
 
-import scala.jdk.CollectionConverters.MapHasAsJava
-
 import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
@@ -47,9 +45,7 @@ import org.apache.spark.util.ArrayImplicits._
  * @param name name of the metric
  * @since 3.3.0
  */
-class Observation(val name: String) {
-
-  if (name.isEmpty) throw new IllegalArgumentException("Name must not be empty")
+class Observation(name: String) extends ObservationBase(name) {
 
   /**
    * Create an Observation instance without providing a name. This generates a random name.
@@ -60,8 +56,6 @@ class Observation(val name: String) {
 
   @volatile private var dataframeId: Option[(SparkSession, Long)] = None
 
-  @volatile private var metrics: Option[Map[String, Any]] = None
-
   /**
    * Attach this observation to the given [[Dataset]] to observe aggregation expressions.
    *
@@ -83,55 +77,6 @@ class Observation(val name: String) {
     ds.observe(name, expr, exprs: _*)
   }
 
-  /**
-   * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish
-   * its first action. Only the result of the first action is available. Subsequent actions do not
-   * modify the result.
-   *
-   * @return the observed metrics as a `Map[String, Any]`
-   * @throws InterruptedException interrupted while waiting
-   */
-  @throws[InterruptedException]
-  def get: Map[String, _] = {
-    synchronized {
-      // we need to loop as wait might return without us calling notify
-      // https://en.wikipedia.org/w/index.php?title=Spurious_wakeup&oldid=992601610
-      while (this.metrics.isEmpty) {
-        wait()
-      }
-    }
-
-    this.metrics.get
-  }
-
-  /**
-   * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish
-   * its first action. Only the result of the first action is available. Subsequent actions do not
-   * modify the result.
-   *
-   * @return the observed metrics as a `java.util.Map[String, Object]`
-   * @throws InterruptedException interrupted while waiting
-   */
-  @throws[InterruptedException]
-  def getAsJava: java.util.Map[String, AnyRef] = {
-      get.map { case (key, value) => (key, value.asInstanceOf[Object])}.asJava
-  }
-
-  /**
-   * Get the observed metrics. This returns the metrics if they are available, otherwise an empty.
-   *
-   * @return the observed metrics as a `Map[String, Any]`
-   */
-  @throws[InterruptedException]
-  private[sql] def getOrEmpty: Map[String, _] = {
-    synchronized {
-      if (metrics.isEmpty) {
-        wait(100) // Wait for 100ms to see if metrics are available
-      }
-      metrics.getOrElse(Map.empty)
-    }
-  }
-
   private[sql] def register(sparkSession: SparkSession, dataframeId: Long): Unit = {
     // makes this class thread-safe:
     // only the first thread entering this block can set sparkSession
@@ -158,9 +103,8 @@ class Observation(val name: String) {
         case _ => false
       }) {
         val row = qe.observedMetrics.get(name)
-        this.metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames.toImmutableArraySeq))
-        if (metrics.isDefined) {
-          notifyAll()
+        val metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames.toImmutableArraySeq))
+        if (setMetricsAndNotify(metrics)) {
           unregister()
         }
       }

