This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a1649ad2429 [SPARK-42034] QueryExecutionListener and Observation API do not work with `foreach` / `reduce` / `foreachPartition` action
a1649ad2429 is described below

commit a1649ad24298d988267acb8588d19848c7fb16c4
Author: 佘志铭 <shezhim...@corp.netease.com>
AuthorDate: Mon Feb 13 14:09:54 2023 +0900

    [SPARK-42034] QueryExecutionListener and Observation API do not work with `foreach` / `reduce` / `foreachPartition` action

    ### What changes were proposed in this pull request?

    Add a name parameter to `Dataset#withNewRDDExecutionId` and pass it for the `foreach` / `reduce` / `foreachPartition` actions, because the QueryExecutionListener and Observation APIs are triggered only for executions that carry a name:

    https://github.com/apache/spark/blob/84ddd409c11e4da769c5b1f496f2b61c3d928c07/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala#L181

    ### Why are the changes needed?

    The QueryExecutionListener and Observation APIs are triggered only when an action runs with a name, so these three actions currently never notify listeners or complete observations.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Added two unit tests.

    Closes #39976 from zzzzming95/SPARK-42034.

    Authored-by: 佘志铭 <shezhim...@corp.netease.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 10 ++++----
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 13 ++++++++++
 .../spark/sql/util/DataFrameCallbackSuite.scala    | 28 ++++++++++++++++++++++
 3 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 28177b90c7e..edcfad0c798 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1858,7 +1858,7 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def reduce(func: (T, T) => T): T = withNewRDDExecutionId {
+  def reduce(func: (T, T) => T): T = withNewRDDExecutionId("reduce") {
     rdd.reduce(func)
   }

@@ -3336,7 +3336,7 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def foreach(f: T => Unit): Unit = withNewRDDExecutionId {
+  def foreach(f: T => Unit): Unit = withNewRDDExecutionId("foreach") {
     rdd.foreach(f)
   }

@@ -3355,7 +3355,7 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId {
+  def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId("foreachPartition") {
     rdd.foreachPartition(f)
   }

@@ -4148,8 +4148,8 @@ class Dataset[T] private[sql](
    * them with an execution. Before performing the action, the metrics of the executed plan will be
    * reset.
    */
-  private def withNewRDDExecutionId[U](body: => U): U = {
-    SQLExecution.withNewExecutionId(rddQueryExecution) {
+  private def withNewRDDExecutionId[U](name: String)(body: => U): U = {
+    SQLExecution.withNewExecutionId(rddQueryExecution, Some(name)) {
       rddQueryExecution.executedPlan.resetMetrics()
       body
     }
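To illustrate how this fix surfaces to user code, here is a minimal sketch (not part of the commit; the object name, session setup, and printed output are assumptions). With the action names in place, a registered QueryExecutionListener now receives "foreach" and "reduce" as funcName once these actions complete:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object ForeachListenerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Print the name of every successfully completed action.
    val listener = new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        println(s"action finished: $funcName")
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
    }
    spark.listenerManager.register(listener)

    val ds = Seq(1, 2, 3).toDS()
    // Before this commit these actions never reached the listener, because no
    // execution name was attached; with the fix this should print
    // "action finished: foreach" and "action finished: reduce".
    ds.foreach(_ => ())
    ds.reduce(_ + _)

    // Listener callbacks are delivered asynchronously on the listener bus,
    // so give them a moment before shutting down.
    Thread.sleep(1000)
    spark.listenerManager.unregister(listener)
    spark.stop()
  }
}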
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 86e640a4fa8..263e361413c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -960,6 +960,19 @@ class DatasetSuite extends QueryTest
     observe(spark.range(1, 10, 1, 11), Map("percentile_approx_val" -> 5))
   }

+  test("observation on datasets when a DataSet trigger foreach action") {
+    def f(): Unit = {}
+
+    val namedObservation = Observation("named")
+    val observed_df = spark.range(100).observe(
+      namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val"))
+
+    observed_df.foreach(r => f)
+    val expected = Map("percentile_approx_val" -> 49)
+
+    assert(namedObservation.get === expected)
+  }
+
   test("sample with replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 2fc1f10d3ea..f046daacb91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -96,6 +96,34 @@ class DataFrameCallbackSuite extends QueryTest
     spark.listenerManager.unregister(listener)
   }

+  test("execute callback functions when a DataSet trigger foreach action finished") {
+    val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]
+    val listener = new QueryExecutionListener {
+      // Only test successful case here, so no need to implement `onFailure`
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        metrics += ((funcName, qe, duration))
+      }
+    }
+    spark.listenerManager.register(listener)
+
+    def f(): Unit = {}
+
+    val df = Seq(1).toDF("i")
+
+    df.foreach(r => f)
+    df.reduce((x, y) => x)
+
+    sparkContext.listenerBus.waitUntilEmpty()
+    assert(metrics.length == 2)
+
+    assert(metrics(0)._1 == "foreach")
+    assert(metrics(1)._1 == "reduce")
+
+    spark.listenerManager.unregister(listener)
+  }
+
   test("get numRows metrics by callback") {
     val metrics = ArrayBuffer.empty[Long]
     val listener = new QueryExecutionListener {
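For the second API fixed here, a minimal sketch of the Observation usage this unblocks (not part of the commit; the object name, metric names, and session setup are assumptions). Observation.get blocks until an action on the observed Dataset reports the metrics, which foreach previously never did:

import org.apache.spark.sql.{Observation, SparkSession}
import org.apache.spark.sql.functions.{count, max}

object ObservationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Attach named metrics to the plan; they are computed as a side effect
    // of whatever action runs on the observed Dataset.
    val observation = Observation("stats")
    val observed = spark.range(100)
      .observe(observation, count($"id").as("rows"), max($"id").as("max_id"))

    // Before this commit, foreach never completed the observation and
    // observation.get would block forever; with the fix it returns
    // Map("rows" -> 100, "max_id" -> 99).
    observed.foreach(_ => ())
    println(observation.get)

    spark.stop()
  }
}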