Repository: spark
Updated Branches:
  refs/heads/master 87f82a5fb -> 07ced4342


[SPARK-11253] [SQL] reset all accumulators in physical operators before executing an action

With this change, our query execution listener can get the metrics correctly: the accumulators in a plan's physical operators are reset before each action, so the values reported to the listener reflect only that action instead of accumulating across repeated executions of the same plan.
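
For illustration, here is a minimal sketch (mirroring the new test below) of how a listener might read an operator's metric after an action. The metric name "numInputRows" matches the aggregate operator exercised in the test; other operators expose different metrics, and a SQLContext with its implicits in scope is assumed.

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener
import sqlContext.implicits._

val listener = new QueryExecutionListener {
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}

  override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
    // Because the accumulators are reset before each action, this reports the
    // rows processed by this action only, not a running total across collects.
    println(qe.executedPlan.longMetric("numInputRows").value.value)
  }
}

sqlContext.listenerManager.register(listener)
Seq(1 -> "a").toDF("i", "j").groupBy("i").count().collect()
sqlContext.listenerManager.unregister(listener)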

The UI still looks good after this change.
<img width="257" alt="screen shot 2015-10-23 at 11 25 14 am" 
src="https://cloud.githubusercontent.com/assets/3182036/10683834/d516f37e-7978-11e5-8118-343ed40eb824.png";>
<img width="494" alt="screen shot 2015-10-23 at 11 25 01 am" 
src="https://cloud.githubusercontent.com/assets/3182036/10683837/e1fa60da-7978-11e5-8ec8-178b88f27764.png";>

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9215 from cloud-fan/metric.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/07ced434
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/07ced434
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/07ced434

Branch: refs/heads/master
Commit: 07ced43424447699e47106de9ca2fa714756bdeb
Parents: 87f82a5
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Sun Oct 25 22:47:39 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Sun Oct 25 22:47:39 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/DataFrame.scala  |  3 +
 .../spark/sql/execution/metric/SQLMetrics.scala |  7 +-
 .../spark/sql/util/DataFrameCallbackSuite.scala | 81 +++++++++++++++++++-
 3 files changed, 87 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/07ced434/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index bf25bcd..25ad3bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1974,6 +1974,9 @@ class DataFrame private[sql](
    */
   private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = {
     try {
+      df.queryExecution.executedPlan.foreach { plan =>
+        plan.metrics.valuesIterator.foreach(_.reset())
+      }
       val start = System.nanoTime()
       val result = action(df)
       val end = System.nanoTime()

http://git-wip-us.apache.org/repos/asf/spark/blob/07ced434/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 075b7ad..1c253e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -28,7 +28,12 @@ import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}
  */
 private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
     name: String, val param: SQLMetricParam[R, T])
-  extends Accumulable[R, T](param.zero, param, Some(name), true)
+  extends Accumulable[R, T](param.zero, param, Some(name), true) {
+
+  def reset(): Unit = {
+    this.value = param.zero
+  }
+}
 
 /**
  * Create a layer for specialized metric. We cannot add `@specialized` to

http://git-wip-us.apache.org/repos/asf/spark/blob/07ced434/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index eb056cd..b46b0d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -17,14 +17,14 @@
 
 package org.apache.spark.sql.util
 
-import org.apache.spark.SparkException
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark._
 import org.apache.spark.sql.{functions, QueryTest}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.test.SharedSQLContext
 
-import scala.collection.mutable.ArrayBuffer
-
 class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
   import functions._
@@ -54,6 +54,8 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     assert(metrics(1)._1 == "count")
     assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate])
     assert(metrics(1)._3 > 0)
+
+    sqlContext.listenerManager.unregister(listener)
   }
 
   test("execute callback functions when a DataFrame action failed") {
@@ -79,5 +81,78 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     assert(metrics(0)._1 == "collect")
     assert(metrics(0)._2.analyzed.isInstanceOf[Project])
     assert(metrics(0)._3.getMessage == e.getMessage)
+
+    sqlContext.listenerManager.unregister(listener)
+  }
+
+  test("get numRows metrics by callback") {
+    val metrics = ArrayBuffer.empty[Long]
+    val listener = new QueryExecutionListener {
+      // Only test successful case here, so no need to implement `onFailure`
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        metrics += qe.executedPlan.longMetric("numInputRows").value.value
+      }
+    }
+    sqlContext.listenerManager.register(listener)
+
+    val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+    df.collect()
+    df.collect()
+    Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+
+    assert(metrics.length == 3)
+    assert(metrics(0) == 1)
+    assert(metrics(1) == 1)
+    assert(metrics(2) == 2)
+
+    sqlContext.listenerManager.unregister(listener)
+  }
+
+  // TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never
+  // updated, we can filter it out later.  However, when we aggregate(sum) accumulator values at
+  // driver side for SQL physical operators, these -1 values will make our result smaller.
+  // An easy fix is to create a new SQLMetric(including new MetricValue, MetricParam, etc.), but we
+  // can do it later because the impact is just too small (1048576 tasks for 1 MB).
+  ignore("get size metrics by callback") {
+    val metrics = ArrayBuffer.empty[Long]
+    val listener = new QueryExecutionListener {
+      // Only test successful case here, so no need to implement `onFailure`
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        metrics += qe.executedPlan.longMetric("dataSize").value.value
+        val bottomAgg = qe.executedPlan.children(0).children(0)
+        metrics += bottomAgg.longMetric("dataSize").value.value
+      }
+    }
+    sqlContext.listenerManager.register(listener)
+
+    val sparkListener = new SaveInfoListener
+    sqlContext.sparkContext.addSparkListener(sparkListener)
+
+    val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j")
+    df.groupBy("i").count().collect()
+
+    def getPeakExecutionMemory(stageId: Int): Long = {
+      val peakMemoryAccumulator = sparkListener.getCompletedStageInfos(stageId).accumulables
+        .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
+
+      assert(peakMemoryAccumulator.size == 1)
+      peakMemoryAccumulator.head._2.value.toLong
+    }
+
+    assert(sparkListener.getCompletedStageInfos.length == 2)
+    val bottomAggDataSize = getPeakExecutionMemory(0)
+    val topAggDataSize = getPeakExecutionMemory(1)
+
+    // For this simple case, the peakExecutionMemory of a stage should be the data size of the
+    // aggregate operator, as we only have one memory consuming operator per stage.
+    assert(metrics.length == 2)
+    assert(metrics(0) == topAggDataSize)
+    assert(metrics(1) == bottomAggDataSize)
+
+    sqlContext.listenerManager.unregister(listener)
   }
 }

