Repository: spark
Updated Branches:
  refs/heads/master e99e34d0f -> 6ed285c68


[SPARK-19447] Fixing input metrics for range operator.

## What changes were proposed in this pull request?

This change introduces a new metric, "number of generated rows". It is used 
exclusively by Range, which is a leaf in the query tree yet reads no input 
data, and therefore cannot report "recordsRead".
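
For reference, the new metric is registered next to the existing one in `RangeExec`; this snippet simply mirrors the change in the diff below:

```scala
// Metric declarations in RangeExec, as in the patch below.
override lazy val metrics = Map(
  "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
  "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows"))
```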

Additionally, the way the metrics are reported by the code-generated 
(whole-stage codegen) version of Range was changed. Previously, all records 
were reported as produced right away, which could be confusing for a user 
monitoring execution progress in the UI. Now, the metric is updated gradually.

To avoid a negative impact on Range performance, the code generation was 
reworked: values are now produced in batches in a tight inner loop, while the 
metrics are updated once per batch in the outer loop.
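
Schematically, the generated loop now has the following shape. This is a simplified, hand-written Scala sketch of what `RangeExec.doProduce` emits, not the actual generated code; `consume` stands in for the whole-stage codegen consume hook, and the metric updates are left as comments:

```scala
// Simplified sketch of the batched range loop (assumes no overflow and no early stop).
def produceRange(start: Long, step: Long, count: Long)(consume: Long => Unit): Unit = {
  val batchSize = 1000L
  var number = start           // next value to emit
  var batchEnd = number        // exclusive end of the current batch, in value space
  var numElementsTodo = count  // values not yet assigned to any batch

  while (true) {
    // Inner loop: emit every value of the current batch; no metric updates here.
    while (number != batchEnd) {
      consume(number)
      number += step
    }
    // Outer loop: size the next batch and account for it in the metrics.
    val nextBatchTodo = math.min(numElementsTodo, batchSize)
    numElementsTodo -= nextBatchTodo
    if (nextBatchTodo == 0) return
    // numOutputRows.add(nextBatchTodo); numGeneratedRows.add(nextBatchTodo)
    batchEnd += nextBatchTodo * step
  }
}
```

For example, `produceRange(0L, 3L, 4L)(println)` prints 0, 3, 6, 9 in a single batch, while a billion-row range would update the metrics roughly once every 1000 rows instead of once up front.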

The change also adds a number of unit tests, which should help ensure the 
correctness of these metrics for various input sources.
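
The tests go through a small helper (`MetricsTestHelper.runAndGetMetrics` in the diff) that runs a DataFrame and collects the per-operator SQL metrics together with the task-level read metrics. A representative assertion from the new suite, checking that Range reads nothing yet reports generated rows:

```scala
// Taken from InputGeneratedOutputMetricsSuite below.
val res = MetricsTestHelper.runAndGetMetrics(
  spark.range(0, 150L, 1).filter(x => x < 100L).toDF())

assert(res.recordsRead.sum === 0)              // no input records read
assert(res.shuffleRecordsRead.sum === 0)       // no shuffle involved
assert(res.generatedRows === 150L :: Nil)      // Range reports "number of generated rows"
assert(res.outputRows === 100L :: 150L :: Nil) // Filter and Range output rows (sorted)
```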

## How was this patch tested?

Unit tests.

Author: Ala Luszczak <a...@databricks.com>

Closes #16829 from ala/SPARK-19447.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ed285c6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ed285c6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ed285c6

Branch: refs/heads/master
Commit: 6ed285c68fee451c45db7b01ca8ec1dea2efd479
Parents: e99e34d
Author: Ala Luszczak <a...@databricks.com>
Authored: Tue Feb 7 14:21:30 2017 +0100
Committer: Reynold Xin <r...@databricks.com>
Committed: Tue Feb 7 14:21:30 2017 +0100

----------------------------------------------------------------------
 .../sql/execution/basicPhysicalOperators.scala  |  82 ++++++++----
 .../apache/spark/sql/DataFrameRangeSuite.scala  | 130 ++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala   |  53 --------
 .../InputGeneratedOutputMetricsSuite.scala      | 131 +++++++++++++++++++
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   |  10 ++
 .../sql/hive/execution/HiveSerDeSuite.scala     |  19 +++
 6 files changed, 350 insertions(+), 75 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index fb90799..792fb3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -339,7 +339,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
   override val output: Seq[Attribute] = range.output
 
   override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"),
+    "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of 
generated rows"))
 
   // output attributes should not affect the results
   override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
@@ -351,24 +352,37 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
   protected override def doProduce(ctx: CodegenContext): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
+    val numGenerated = metricTerm(ctx, "numGeneratedRows")
 
     val initTerm = ctx.freshName("initRange")
     ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
-    val partitionEnd = ctx.freshName("partitionEnd")
-    ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
     val number = ctx.freshName("number")
     ctx.addMutableState("long", number, s"$number = 0L;")
-    val overflow = ctx.freshName("overflow")
-    ctx.addMutableState("boolean", overflow, s"$overflow = false;")
 
     val value = ctx.freshName("value")
     val ev = ExprCode("", "false", value)
     val BigInt = classOf[java.math.BigInteger].getName
-    val checkEnd = if (step > 0) {
-      s"$number < $partitionEnd"
-    } else {
-      s"$number > $partitionEnd"
-    }
+
+    // In order to periodically update the metrics without inflicting performance penalty, this
+    // operator produces elements in batches. After a batch is complete, the metrics are updated
+    // and a new batch is started.
+    // In the implementation below, the code in the inner loop is producing all the values
+    // within a batch, while the code in the outer loop is setting batch parameters and updating
+    // the metrics.
+
+    // Once number == batchEnd, it's time to progress to the next batch.
+    val batchEnd = ctx.freshName("batchEnd")
+    ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;")
+
+    // How many values should still be generated by this range operator.
+    val numElementsTodo = ctx.freshName("numElementsTodo")
+    ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;")
+
+    // How many values should be generated in the next batch.
+    val nextBatchTodo = ctx.freshName("nextBatchTodo")
+
+    // The default size of a batch.
+    val batchSize = 1000L
 
     ctx.addNewFunction("initRange",
       s"""
@@ -378,6 +392,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
         |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
         |   $BigInt step = $BigInt.valueOf(${step}L);
         |   $BigInt start = $BigInt.valueOf(${start}L);
+        |   long partitionEnd;
         |
        |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
         |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
@@ -387,18 +402,26 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
         |   } else {
         |     $number = st.longValue();
         |   }
+        |   $batchEnd = $number;
         |
        |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
         |     .multiply(step).add(start);
         |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-        |     $partitionEnd = Long.MAX_VALUE;
+        |     partitionEnd = Long.MAX_VALUE;
         |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-        |     $partitionEnd = Long.MIN_VALUE;
+        |     partitionEnd = Long.MIN_VALUE;
         |   } else {
-        |     $partitionEnd = end.longValue();
+        |     partitionEnd = end.longValue();
         |   }
         |
-        |   $numOutput.add(($partitionEnd - $number) / ${step}L);
+        |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
+        |     $BigInt.valueOf($number));
+        |   $numElementsTodo  = startToEnd.divide(step).longValue();
+        |   if ($numElementsTodo < 0) {
+        |     $numElementsTodo = 0;
+        |   } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
+        |     $numElementsTodo++;
+        |   }
         | }
        """.stripMargin)
 
@@ -412,20 +435,34 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |   initRange(partitionIndex);
       | }
       |
-      | while (!$overflow && $checkEnd) {
-      |  long $value = $number;
-      |  $number += ${step}L;
-      |  if ($number < $value ^ ${step}L < 0) {
-      |    $overflow = true;
-      |  }
-      |  ${consume(ctx, Seq(ev))}
-      |  if (shouldStop()) return;
+      | while (true) {
+      |   while ($number != $batchEnd) {
+      |     long $value = $number;
+      |     $number += ${step}L;
+      |     ${consume(ctx, Seq(ev))}
+      |     if (shouldStop()) return;
+      |   }
+      |
+      |   long $nextBatchTodo;
+      |   if ($numElementsTodo > ${batchSize}L) {
+      |     $nextBatchTodo = ${batchSize}L;
+      |     $numElementsTodo -= ${batchSize}L;
+      |   } else {
+      |     $nextBatchTodo = $numElementsTodo;
+      |     $numElementsTodo = 0;
+      |     if ($nextBatchTodo == 0) break;
+      |   }
+      |   $numOutput.add($nextBatchTodo);
+      |   $numGenerated.add($nextBatchTodo);
+      |
+      |   $batchEnd += $nextBatchTodo * ${step}L;
       | }
      """.stripMargin
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val numGeneratedRows = longMetric("numGeneratedRows")
     sqlContext
       .sparkContext
       .parallelize(0 until numSlices, numSlices)
@@ -469,6 +506,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
             }
 
             numOutputRows += 1
+            numGeneratedRows += 1
             unsafeRow.setLong(0, ret)
             unsafeRow
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
new file mode 100644
index 0000000..6d2d776
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.math.abs
+import scala.util.Random
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
+
+  test("SPARK-7150 range api") {
+    // numSlice is greater than length
+    val res1 = spark.range(0, 10, 1, 15).select("id")
+    assert(res1.count == 10)
+    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res2 = spark.range(3, 15, 3, 2).select("id")
+    assert(res2.count == 4)
+    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+    val res3 = spark.range(1, -2).select("id")
+    assert(res3.count == 0)
+
+    // start is positive, end is negative, step is negative
+    val res4 = spark.range(1, -2, -2, 6).select("id")
+    assert(res4.count == 2)
+    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
+
+    // start, end, step are negative
+    val res5 = spark.range(-3, -8, -2, 1).select("id")
+    assert(res5.count == 3)
+    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
+
+    // start, end are negative, step is positive
+    val res6 = spark.range(-8, -4, 2, 1).select("id")
+    assert(res6.count == 2)
+    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
+
+    val res7 = spark.range(-10, -9, -20, 1).select("id")
+    assert(res7.count == 0)
+
+    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+    assert(res8.count == 3)
+    assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
+
+    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+    assert(res9.count == 2)
+    assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+
+    // only end provided as argument
+    val res10 = spark.range(10).select("id")
+    assert(res10.count == 10)
+    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res11 = spark.range(-1).select("id")
+    assert(res11.count == 0)
+
+    // using the default slice number
+    val res12 = spark.range(3, 15, 3).select("id")
+    assert(res12.count == 4)
+    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+    // difference between range start and end does not fit in a 64-bit integer
+    val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
+    val res13 = spark.range(-n, n, n / 9).select("id")
+    assert(res13.count == 18)
+  }
+
+  test("Range with randomized parameters") {
+    val MAX_NUM_STEPS = 10L * 1000
+
+    val seed = System.currentTimeMillis()
+    val random = new Random(seed)
+
+    def randomBound(): Long = {
+      val n = if (random.nextBoolean()) {
+        random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
+      } else {
+        random.nextLong() / 2
+      }
+      if (random.nextBoolean()) n else -n
+    }
+
+    for (l <- 1 to 10) {
+      val start = randomBound()
+      val end = randomBound()
+      val numSteps = (abs(random.nextLong()) % MAX_NUM_STEPS) + 1
+      val stepAbs = (abs(end - start) / numSteps) + 1
+      val step = if (start < end) stepAbs else -stepAbs
+      val partitions = random.nextInt(20) + 1
+
+      val expCount = (start until end by step).size
+      val expSum = (start until end by step).sum
+
+      for (codegen <- List(false, true)) {
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+          val res = spark.range(start, end, step, partitions).toDF("id").
+            agg(count("id"), sum("id")).collect()
+
+          withClue(s"seed = $seed start = $start end = $end step = $step partitions = " +
+              s"$partitions codegen = $codegen") {
+            assert(!res.isEmpty)
+            assert(res.head.getLong(0) == expCount)
+            if (expCount > 0) {
+              assert(res.head.getLong(1) == expSum)
+            }
+          }
+        }
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6a190b9..e6338ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -979,59 +979,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
   }
 
-  test("SPARK-7150 range api") {
-    // numSlice is greater than length
-    val res1 = spark.range(0, 10, 1, 15).select("id")
-    assert(res1.count == 10)
-    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
-    val res2 = spark.range(3, 15, 3, 2).select("id")
-    assert(res2.count == 4)
-    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-
-    val res3 = spark.range(1, -2).select("id")
-    assert(res3.count == 0)
-
-    // start is positive, end is negative, step is negative
-    val res4 = spark.range(1, -2, -2, 6).select("id")
-    assert(res4.count == 2)
-    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
-
-    // start, end, step are negative
-    val res5 = spark.range(-3, -8, -2, 1).select("id")
-    assert(res5.count == 3)
-    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
-
-    // start, end are negative, step is positive
-    val res6 = spark.range(-8, -4, 2, 1).select("id")
-    assert(res6.count == 2)
-    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
-
-    val res7 = spark.range(-10, -9, -20, 1).select("id")
-    assert(res7.count == 0)
-
-    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
-    assert(res8.count == 3)
-    assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
-
-    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
-    assert(res9.count == 2)
-    assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
-
-    // only end provided as argument
-    val res10 = spark.range(10).select("id")
-    assert(res10.count == 10)
-    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
-    val res11 = spark.range(-1).select("id")
-    assert(res11.count == 0)
-
-    // using the default slice number
-    val res12 = spark.range(3, 15, 3).select("id")
-    assert(res12.count == 4)
-    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-  }
-
   test("SPARK-8621: support empty string column name") {
     val df = Seq(Tuple1(1)).toDF("").as("t")
     // We should allow empty string as column name

http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
new file mode 100644
index 0000000..ddd7a03
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.File
+
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.Utils
+
+class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
+
+  test("Range query input/output/generated metrics") {
+    val numRows = 150L
+    val numSelectedRows = 100L
+    val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
+      filter(x => x < numSelectedRows).toDF())
+
+    assert(res.recordsRead.sum === 0)
+    assert(res.shuffleRecordsRead.sum === 0)
+    assert(res.generatedRows === numRows :: Nil)
+    assert(res.outputRows === numSelectedRows :: numRows :: Nil)
+  }
+
+  test("Input/output/generated metrics with repartitioning") {
+    val numRows = 100L
+    val res = MetricsTestHelper.runAndGetMetrics(
+      spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
+
+    assert(res.recordsRead.sum === 0)
+    assert(res.shuffleRecordsRead.sum === numRows)
+    assert(res.generatedRows === numRows :: Nil)
+    assert(res.outputRows === 20 :: numRows :: Nil)
+  }
+
+  test("Input/output/generated metrics with more repartitioning") {
+    withTempDir { tempDir =>
+      val dir = new File(tempDir, "pqS").getCanonicalPath
+
+      spark.range(10).write.parquet(dir)
+      spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+      val res = MetricsTestHelper.runAndGetMetrics(
+        spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
+            .toDF()
+      )
+
+      assert(res.recordsRead.sum == 10)
+      assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
+      assert(res.generatedRows == 30 :: Nil)
+      assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
+    }
+  }
+}
+
+object MetricsTestHelper {
+  case class AggregatedMetricsResult(
+      recordsRead: List[Long],
+      shuffleRecordsRead: List[Long],
+      generatedRows: List[Long],
+      outputRows: List[Long])
+
+  private[this] def extractMetricValues(
+      df: DataFrame,
+      metricValues: Map[Long, String],
+      metricName: String): List[Long] = {
+    df.queryExecution.executedPlan.collect {
+      case plan if plan.metrics.contains(metricName) =>
+        metricValues(plan.metrics(metricName).id).toLong
+    }.toList.sorted
+  }
+
+  def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
+      AggregatedMetricsResult = {
+    val spark = df.sparkSession
+    val sparkContext = spark.sparkContext
+
+    var recordsRead = List[Long]()
+    var shuffleRecordsRead = List[Long]()
+    val listener = new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        if (taskEnd.taskMetrics != null) {
+          recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
+            recordsRead
+          shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
+            shuffleRecordsRead
+        }
+      }
+    }
+
+    val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
+
+    val prevUseWholeStageCodeGen =
+      spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
+    try {
+      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+      sparkContext.addSparkListener(listener)
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+    } finally {
+      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
+    }
+
+    val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
+    val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
+    val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
+    val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
+
+    AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 0396254..14fbe9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.MetricsTestHelper
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -915,4 +916,13 @@ class JDBCSuite extends SparkFunSuite
     }.getMessage
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }
+
+  test("Input/generated/output metrics on JDBC") {
+    val foobarCnt = spark.table("foobar").count()
+    val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
+    assert(res.recordsRead === foobarCnt :: Nil)
+    assert(res.shuffleRecordsRead.sum === 0)
+    assert(res.generatedRows.isEmpty)
+    assert(res.outputRows === foobarCnt :: Nil)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index ec620c2..35c41b5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.sql.execution.MetricsTestHelper
 import org.apache.spark.sql.hive.test.TestHive
 
 /**
@@ -47,4 +48,22 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
   createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes")
 
   createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM 
episodes_part")
+
+  test("Test input/generated/output metrics") {
+    import TestHive._
+
+    val episodesCnt = sql("select * from episodes").count()
+    val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
+    assert(episodesRes.recordsRead === episodesCnt :: Nil)
+    assert(episodesRes.shuffleRecordsRead.sum === 0)
+    assert(episodesRes.generatedRows.isEmpty)
+    assert(episodesRes.outputRows === episodesCnt :: Nil)
+
+    val serdeinsCnt = sql("select * from serdeins").count()
+    val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
+    assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
+    assert(serdeinsRes.shuffleRecordsRead.sum === 0)
+    assert(serdeinsRes.generatedRows.isEmpty)
+    assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+  }
 }

