This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6cd5cf5  [SPARK-35215][SQL] Update custom metric per certain rows and 
at the end of the task
6cd5cf5 is described below

commit 6cd5cf57229050ba9542a644ed0a4c844949a832
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Thu May 6 13:21:08 2021 +0000

    [SPARK-35215][SQL] Update custom metric per certain rows and at the end of 
the task
    
    ### What changes were proposed in this pull request?
    
    This patch changes custom metric updates to happen once every fixed number of
    rows (currently 100) and once more at the end of the task, instead of on every row.
    
    ### Why are the changes needed?
    
    Based on the previous discussion at
    https://github.com/apache/spark/pull/31451#discussion_r605413557, we should
    only update custom metrics once every certain number of rows (e.g. 100) and
    also at the end of the task. Updating on every row provides little benefit.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing unit tests.
    
    Closes #32330 from viirya/metric-update.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/execution/datasources/v2/DataSourceRDD.scala  | 19 ++++++++++++-------
 .../spark/sql/execution/metric/CustomMetrics.scala    | 15 ++++++++++++++-
 .../continuous/ContinuousDataSourceRDD.scala          |  9 ++++++---
 3 files changed, 32 insertions(+), 11 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 7850dfa..217a1d5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, 
PartitionReaderFactory}
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class DataSourceRDDPartition(val index: Int, val inputPartition: 
InputPartition)
@@ -66,7 +66,12 @@ class DataSourceRDD(
         new PartitionIterator[InternalRow](rowReader, customMetrics))
       (iter, rowReader)
     }
-    context.addTaskCompletionListener[Unit](_ => reader.close())
+    context.addTaskCompletionListener[Unit] { _ =>
+      // In case of early stopping before consuming the entire iterator,
+      // we need to do one more metric update at the end of the task.
+      CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
+      reader.close()
+    }
     // TODO: SPARK-25083 remove the type erasure hack in data source scan
     new InterruptibleIterator(context, 
iter.asInstanceOf[Iterator[InternalRow]])
   }
@@ -81,6 +86,8 @@ private class PartitionIterator[T](
     customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
   private[this] var valuePrepared = false
 
+  private var numRow = 0L
+
   override def hasNext: Boolean = {
     if (!valuePrepared) {
       valuePrepared = reader.next()
@@ -92,12 +99,10 @@ private class PartitionIterator[T](
     if (!hasNext) {
       throw QueryExecutionErrors.endOfStreamError()
     }
-    reader.currentMetricsValues.foreach { metric =>
-      assert(customMetrics.contains(metric.name()),
-        s"Custom metrics ${customMetrics.keys.mkString(", ")} do not contain 
the metric " +
-          s"${metric.name()}")
-      customMetrics(metric.name()).set(metric.value())
+    if (numRow % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
+      CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
     }
+    numRow += 1
     valuePrepared = false
     reader.get()
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
index f2449a1..3e6cad2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.sql.execution.metric
 
-import org.apache.spark.sql.connector.metric.CustomMetric
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 
 object CustomMetrics {
   private[spark] val V2_CUSTOM = "v2Custom"
 
+  private[spark] val NUM_ROWS_PER_UPDATE = 100
+
   /**
    * Given a class name, builds and returns a metric type for a V2 custom 
metric class
    * `CustomMetric`.
@@ -41,4 +43,15 @@ object CustomMetrics {
       None
     }
   }
+
+  /**
+   * Updates given custom metrics.
+   */
+  def updateMetrics(
+      currentMetricsValues: Seq[CustomTaskMetric],
+      customMetrics: Map[String, SQLMetric]): Unit = {
+    currentMetricsValues.foreach { metric =>
+      customMetrics(metric.name()).set(metric.value())
+    }
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
index 4e32cef..6d27961 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -22,7 +22,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.InputPartition
 import 
org.apache.spark.sql.connector.read.streaming.ContinuousPartitionReaderFactory
-import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.NextIterator
 
@@ -92,10 +92,13 @@ class ContinuousDataSourceRDD(
 
     val partitionReader = readerForPartition.getPartitionReader()
     new NextIterator[InternalRow] {
+      private var numRow = 0L
+
       override def getNext(): InternalRow = {
-        partitionReader.currentMetricsValues.foreach { metric =>
-          customMetrics(metric.name()).set(metric.value())
+        if (numRow % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
+          CustomMetrics.updateMetrics(partitionReader.currentMetricsValues, 
customMetrics)
         }
+        numRow += 1
         readerForPartition.next() match {
           case null =>
             finished = true

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to