This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7db9b2293fa [SPARK-45656][SQL] Fix observation when named observations with the same name on different datasets
7db9b2293fa is described below

commit 7db9b2293fa778073274d235dd72212b75d94073
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Oct 25 16:59:26 2023 +0900

    [SPARK-45656][SQL] Fix observation when named observations with the same name on different datasets
    
    ### What changes were proposed in this pull request?
    
    Fixes `Observation` so that named observations with the same name can be used on different datasets.
    
    ### Why are the changes needed?
    
    Currently, if there are observations with the same name on different datasets, one of them is overwritten by the other's execution.
    
    For example,
    
    ```py
    >>> observation1 = Observation("named")
    >>> df1 = spark.range(50)
    >>> observed_df1 = df1.observe(observation1, count(lit(1)).alias("cnt"))
    >>>
    >>> observation2 = Observation("named")
    >>> df2 = spark.range(100)
    >>> observed_df2 = df2.observe(observation2, count(lit(1)).alias("cnt"))
    >>>
    >>> observed_df1.collect()
    ...
    >>> observed_df2.collect()
    ...
    >>> observation1.get
    {'cnt': 50}
    >>> observation2.get
    {'cnt': 50}
    ```
    
    `observation2` should return `{'cnt': 100}`.
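
    The root cause is that `Observation` matched a finished query's metrics by observation name alone, so whichever execution completed first populated both `Observation` instances. The fix keeps track of the observed `Dataset` (not just its `SparkSession`) and only accepts metrics from a query whose `CollectMetrics` node carries both the same name and the id of that `Dataset`, as shown in the `Observation.scala` change below.

    For illustration, the following is a minimal, self-contained Scala sketch of the now-correct behavior, mirroring the new `DatasetSuite` test (it assumes a local `SparkSession` and is not part of the patch):

    ```scala
    import org.apache.spark.sql.{Observation, SparkSession}
    import org.apache.spark.sql.functions.{count, lit}

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Two distinct Observation instances that happen to share the name "named".
    val observation1 = Observation("named")
    val observedDf1 = spark.range(50).observe(observation1, count(lit(1)).as("cnt"))

    val observation2 = Observation("named")
    val observedDf2 = spark.range(100).observe(observation2, count(lit(1)).as("cnt"))

    observedDf1.collect()
    observedDf2.collect()

    // Each observation now reports the metrics of its own dataset;
    // before this fix, observation2 also returned Map("cnt" -> 50).
    assert(observation1.get == Map("cnt" -> 50L))
    assert(observation2.get == Map("cnt" -> 100L))
    ```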
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, observations with the same name will now be available if they observe different datasets.
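
    Note that reusing the *same* `Observation` instance on a second dataset is still rejected; only distinct instances that happen to share a name are supported. A minimal sketch of this unchanged restriction (assuming `spark`, `count`, and `lit` are in scope as in the sketch above; not part of the patch):

    ```scala
    val obs = Observation("named")
    spark.range(10).observe(obs, count(lit(1)).as("cnt")).collect()

    // Registering the same instance with a second Dataset still throws:
    // IllegalArgumentException("An Observation can be used with a Dataset only once")
    spark.range(20).observe(obs, count(lit(1)).as("cnt"))
    ```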
    
    ### How was this patch tested?
    
    Added the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43519 from ueshin/issues/SPARK-45656/observation.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_dataframe.py          | 18 ++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala   |  2 +-
 .../scala/org/apache/spark/sql/Observation.scala    | 21 +++++++++++++--------
 .../scala/org/apache/spark/sql/DatasetSuite.scala   | 21 +++++++++++++++++++++
 4 files changed, 53 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 3c493a8ae3a..0a2e3a53946 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1023,6 +1023,24 @@ class DataFrameTestsMixin:
         self.assertGreaterEqual(row.cnt, 0)
         self.assertGreaterEqual(row.sum, 0)
 
+    def test_observe_with_same_name_on_different_dataframe(self):
+        # SPARK-45656: named observations with the same name on different datasets
+        from pyspark.sql import Observation
+
+        observation1 = Observation("named")
+        df1 = self.spark.range(50)
+        observed_df1 = df1.observe(observation1, count(lit(1)).alias("cnt"))
+
+        observation2 = Observation("named")
+        df2 = self.spark.range(100)
+        observed_df2 = df2.observe(observation2, count(lit(1)).alias("cnt"))
+
+        observed_df1.collect()
+        observed_df2.collect()
+
+        self.assertEqual(observation1.get, dict(cnt=50))
+        self.assertEqual(observation2.get, dict(cnt=100))
+
     def test_sample(self):
         with self.assertRaises(PySparkTypeError) as pe:
             self.spark.range(1).sample()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5079cfcca9d..4f07133bb76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -201,7 +201,7 @@ class Dataset[T] private[sql](
   }
 
   // A globally unique id of this Dataset.
-  private val id = Dataset.curId.getAndIncrement()
+  private[sql] val id = Dataset.curId.getAndIncrement()
 
   queryExecution.assertAnalyzed()
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
index ba40336fc14..14c4983794b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -21,6 +21,7 @@ import java.util.UUID
 
 import scala.jdk.CollectionConverters.MapHasAsJava
 
+import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 
@@ -56,7 +57,7 @@ class Observation(val name: String) {
 
   private val listener: ObservationListener = ObservationListener(this)
 
-  @volatile private var sparkSession: Option[SparkSession] = None
+  @volatile private var ds: Option[Dataset[_]] = None
 
   @volatile private var metrics: Option[Map[String, Any]] = None
 
@@ -74,7 +75,7 @@ class Observation(val name: String) {
     if (ds.isStreaming) {
       throw new IllegalArgumentException("Observation does not support 
streaming Datasets")
     }
-    register(ds.sparkSession)
+    register(ds)
     ds.observe(name, expr, exprs: _*)
   }
 
@@ -112,27 +113,31 @@ class Observation(val name: String) {
       get.map { case (key, value) => (key, value.asInstanceOf[Object])}.asJava
   }
 
-  private def register(sparkSession: SparkSession): Unit = {
+  private def register(ds: Dataset[_]): Unit = {
     // makes this class thread-safe:
     // only the first thread entering this block can set sparkSession
    // all other threads will see the exception, as it is only allowed to do this once
     synchronized {
-      if (this.sparkSession.isDefined) {
+      if (this.ds.isDefined) {
         throw new IllegalArgumentException("An Observation can be used with a 
Dataset only once")
       }
-      this.sparkSession = Some(sparkSession)
+      this.ds = Some(ds)
     }
 
-    sparkSession.listenerManager.register(this.listener)
+    ds.sparkSession.listenerManager.register(this.listener)
   }
 
   private def unregister(): Unit = {
-    this.sparkSession.foreach(_.listenerManager.unregister(this.listener))
+    this.ds.foreach(_.sparkSession.listenerManager.unregister(this.listener))
   }
 
   private[spark] def onFinish(qe: QueryExecution): Unit = {
     synchronized {
-      if (this.metrics.isEmpty) {
+      if (this.metrics.isEmpty && qe.logical.exists {
+        case CollectMetrics(name, _, _, dataframeId) =>
+          name == this.name && dataframeId == ds.get.id
+        case _ => false
+      }) {
         val row = qe.observedMetrics.get(name)
         this.metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames))
         if (metrics.isDefined) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 51fa3cd5916..6b00799cabd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1024,6 +1024,27 @@ class DatasetSuite extends QueryTest
     assert(namedObservation.get === expected)
   }
 
+  test("SPARK-45656: named observations with the same name on different 
datasets") {
+    val namedObservation1 = Observation("named")
+    val df1 = spark.range(50)
+    val observed_df1 = df1.observe(
+      namedObservation1, count(lit(1)).as("count"))
+
+    val namedObservation2 = Observation("named")
+    val df2 = spark.range(100)
+    val observed_df2 = df2.observe(
+      namedObservation2, count(lit(1)).as("count"))
+
+    observed_df1.collect()
+    observed_df2.collect()
+
+    val expected1 = Map("count" -> 50)
+    val expected2 = Map("count" -> 100)
+
+    assert(namedObservation1.get === expected1)
+    assert(namedObservation2.get === expected2)
+  }
+
   test("sample with replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()

