Repository: spark
Updated Branches:
  refs/heads/master e99e34d0f -> 6ed285c68
[SPARK-19447] Fixing input metrics for range operator.

## What changes were proposed in this pull request?

This change introduces a new metric, "number of generated rows". It is used exclusively for Range, which is a leaf in the query tree, yet doesn't read any input data, and therefore cannot report "recordsRead".

Additionally, the way in which the metrics are reported by the code-generated (whole-stage codegen) version of Range was changed. Previously, it was immediately reported that all the records had been produced. This could be confusing for a user monitoring execution progress in the UI. Now, the metric is updated gradually.

In order to avoid a negative impact on Range performance, the code generation was reworked. The values are now produced in batches in the tight inner loop, while the metrics are updated in the outer loop (see the sketch below).

The change also contains a number of unit tests, which should help ensure the correctness of metrics for various input sources.

## How was this patch tested?

Unit tests.

Author: Ala Luszczak <a...@databricks.com>

Closes #16829 from ala/SPARK-19447.
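[Editor's note] A minimal, standalone sketch of the batching idea described above, in plain Scala rather than the generated Java code; the names (`produceInBatches`, `emit`) are illustrative only:

```scala
// Produce `total` values in batches of `batchSize`, touching the shared
// counter once per batch instead of once per row, so the hot inner loop
// stays free of metric bookkeeping.
def produceInBatches(total: Long, batchSize: Long)(emit: Long => Unit): Long = {
  var generatedRows = 0L   // stands in for the "number of generated rows" metric
  var number = 0L          // next value to emit
  var remaining = total    // values still left to generate
  var batchEnd = 0L        // inner loop runs until `number` reaches this bound
  var done = false
  while (!done) {
    while (number != batchEnd) {   // tight inner loop: no metric updates here
      emit(number)
      number += 1
    }
    val nextBatch = math.min(remaining, batchSize)
    if (nextBatch == 0) {
      done = true
    } else {
      remaining -= nextBatch
      generatedRows += nextBatch   // outer loop: metric updated once per batch
      batchEnd += nextBatch
    }
  }
  generatedRows
}

// produceInBatches(2500L, 1000L)(_ => ()) returns 2500 after only three
// counter updates (1000, 1000, 500) instead of 2500 per-row updates.
```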
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ed285c6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ed285c6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ed285c6

Branch: refs/heads/master
Commit: 6ed285c68fee451c45db7b01ca8ec1dea2efd479
Parents: e99e34d
Author: Ala Luszczak <a...@databricks.com>
Authored: Tue Feb 7 14:21:30 2017 +0100
Committer: Reynold Xin <r...@databricks.com>
Committed: Tue Feb 7 14:21:30 2017 +0100

----------------------------------------------------------------------
 .../sql/execution/basicPhysicalOperators.scala |  82 ++++++++----
 .../apache/spark/sql/DataFrameRangeSuite.scala | 130 ++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala  |  53 --------
 .../InputGeneratedOutputMetricsSuite.scala     | 131 +++++++++++++++++++
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala  |  10 ++
 .../sql/hive/execution/HiveSerDeSuite.scala    |  19 +++
 6 files changed, 350 insertions(+), 75 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index fb90799..792fb3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -339,7 +339,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
   override val output: Seq[Attribute] = range.output
 
   override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows"))
 
   // output attributes should not affect the results
   override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
@@ -351,24 +352,37 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
   protected override def doProduce(ctx: CodegenContext): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
+    val numGenerated = metricTerm(ctx, "numGeneratedRows")
 
     val initTerm = ctx.freshName("initRange")
     ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
-    val partitionEnd = ctx.freshName("partitionEnd")
-    ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
     val number = ctx.freshName("number")
     ctx.addMutableState("long", number, s"$number = 0L;")
-    val overflow = ctx.freshName("overflow")
-    ctx.addMutableState("boolean", overflow, s"$overflow = false;")
 
     val value = ctx.freshName("value")
     val ev = ExprCode("", "false", value)
     val BigInt = classOf[java.math.BigInteger].getName
-    val checkEnd = if (step > 0) {
-      s"$number < $partitionEnd"
-    } else {
-      s"$number > $partitionEnd"
-    }
+
+    // In order to periodically update the metrics without inflicting performance penalty, this
+    // operator produces elements in batches. After a batch is complete, the metrics are updated
+    // and a new batch is started.
+    // In the implementation below, the code in the inner loop is producing all the values
+    // within a batch, while the code in the outer loop is setting batch parameters and updating
+    // the metrics.
+
+    // Once number == batchEnd, it's time to progress to the next batch.
+    val batchEnd = ctx.freshName("batchEnd")
+    ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;")
+
+    // How many values should still be generated by this range operator.
+    val numElementsTodo = ctx.freshName("numElementsTodo")
+    ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;")
+
+    // How many values should be generated in the next batch.
+    val nextBatchTodo = ctx.freshName("nextBatchTodo")
+
+    // The default size of a batch.
+    val batchSize = 1000L
 
     ctx.addNewFunction("initRange",
       s"""
@@ -378,6 +392,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
         |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
         |   $BigInt step = $BigInt.valueOf(${step}L);
         |   $BigInt start = $BigInt.valueOf(${start}L);
+        |   long partitionEnd;
         |
         |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
         |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
@@ -387,18 +402,26 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
         |   } else {
         |     $number = st.longValue();
         |   }
+        |   $batchEnd = $number;
         |
         |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
         |     .multiply(step).add(start);
         |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-        |     $partitionEnd = Long.MAX_VALUE;
+        |     partitionEnd = Long.MAX_VALUE;
         |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-        |     $partitionEnd = Long.MIN_VALUE;
+        |     partitionEnd = Long.MIN_VALUE;
        |   } else {
-        |     $partitionEnd = end.longValue();
+        |     partitionEnd = end.longValue();
         |   }
         |
-        |   $numOutput.add(($partitionEnd - $number) / ${step}L);
+        |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
+        |     $BigInt.valueOf($number));
+        |   $numElementsTodo  = startToEnd.divide(step).longValue();
+        |   if ($numElementsTodo < 0) {
+        |     $numElementsTodo = 0;
+        |   } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
+        |     $numElementsTodo++;
+        |   }
         | }
       """.stripMargin)
@@ -412,20 +435,34 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |   initRange(partitionIndex);
       | }
       |
-      | while (!$overflow && $checkEnd) {
-      |   long $value = $number;
-      |   $number += ${step}L;
-      |   if ($number < $value ^ ${step}L < 0) {
-      |     $overflow = true;
-      |   }
-      |   ${consume(ctx, Seq(ev))}
-      |   if (shouldStop()) return;
+      | while (true) {
+      |   while ($number != $batchEnd) {
+      |     long $value = $number;
+      |     $number += ${step}L;
+      |     ${consume(ctx, Seq(ev))}
+      |     if (shouldStop()) return;
+      |   }
+      |
+      |   long $nextBatchTodo;
+      |   if ($numElementsTodo > ${batchSize}L) {
+      |     $nextBatchTodo = ${batchSize}L;
+      |     $numElementsTodo -= ${batchSize}L;
+      |   } else {
+      |     $nextBatchTodo = $numElementsTodo;
+      |     $numElementsTodo = 0;
+      |     if ($nextBatchTodo == 0) break;
+      |   }
+      |   $numOutput.add($nextBatchTodo);
+      |   $numGenerated.add($nextBatchTodo);
+      |
+      |   $batchEnd += $nextBatchTodo * ${step}L;
       | }
     """.stripMargin
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val numGeneratedRows = longMetric("numGeneratedRows")
     sqlContext
       .sparkContext
       .parallelize(0 until numSlices, numSlices)
@@ -469,6 +506,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
           }
 
           numOutputRows += 1
+          numGeneratedRows += 1
           unsafeRow.setLong(0, ret)
           unsafeRow
         }
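[Editor's note] A quick way to see the new metric in action. This is a hedged sketch, assuming a running SparkSession named `spark` with this patch applied; it reads the RangeExec metrics off the executed plan after an action:

```scala
import org.apache.spark.sql.execution.RangeExec

val df = spark.range(0, 1000, 1, 4)
df.collect()  // run the query so the metric accumulators get populated

// "numGeneratedRows" is registered on RangeExec by this patch; with whole-stage
// codegen enabled the operator sits inside WholeStageCodegenExec, but collect()
// on the plan tree still finds it.
val generated = df.queryExecution.executedPlan.collect {
  case r: RangeExec => r.metrics("numGeneratedRows").value
}
// expected: Seq(1000)
println(generated)
```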
http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
new file mode 100644
index 0000000..6d2d776
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.math.abs
+import scala.util.Random
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
+
+  test("SPARK-7150 range api") {
+    // numSlice is greater than length
+    val res1 = spark.range(0, 10, 1, 15).select("id")
+    assert(res1.count == 10)
+    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res2 = spark.range(3, 15, 3, 2).select("id")
+    assert(res2.count == 4)
+    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+    val res3 = spark.range(1, -2).select("id")
+    assert(res3.count == 0)
+
+    // start is positive, end is negative, step is negative
+    val res4 = spark.range(1, -2, -2, 6).select("id")
+    assert(res4.count == 2)
+    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
+
+    // start, end, step are negative
+    val res5 = spark.range(-3, -8, -2, 1).select("id")
+    assert(res5.count == 3)
+    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
+
+    // start, end are negative, step is positive
+    val res6 = spark.range(-8, -4, 2, 1).select("id")
+    assert(res6.count == 2)
+    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
+
+    val res7 = spark.range(-10, -9, -20, 1).select("id")
+    assert(res7.count == 0)
+
+    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+    assert(res8.count == 3)
+    assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
+
+    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+    assert(res9.count == 2)
+    assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+
+    // only end provided as argument
+    val res10 = spark.range(10).select("id")
+    assert(res10.count == 10)
+    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res11 = spark.range(-1).select("id")
+    assert(res11.count == 0)
+
+    // using the default slice number
+    val res12 = spark.range(3, 15, 3).select("id")
+    assert(res12.count == 4)
+    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+    // difference between range start and end does not fit in a 64-bit integer
+    val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
+    val res13 = spark.range(-n, n, n / 9).select("id")
+    assert(res13.count == 18)
+  }
+ agg(count("id"), sum("id")).collect() + + withClue(s"seed = $seed start = $start end = $end step = $step partitions = " + + s"$partitions codegen = $codegen") { + assert(!res.isEmpty) + assert(res.head.getLong(0) == expCount) + if (expCount > 0) { + assert(res.head.getLong(1) == expSum) + } + } + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6a190b9..e6338ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -979,59 +979,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2))) } - test("SPARK-7150 range api") { - // numSlice is greater than length - val res1 = spark.range(0, 10, 1, 15).select("id") - assert(res1.count == 10) - assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - - val res2 = spark.range(3, 15, 3, 2).select("id") - assert(res2.count == 4) - assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - - val res3 = spark.range(1, -2).select("id") - assert(res3.count == 0) - - // start is positive, end is negative, step is negative - val res4 = spark.range(1, -2, -2, 6).select("id") - assert(res4.count == 2) - assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) - - // start, end, step are negative - val res5 = spark.range(-3, -8, -2, 1).select("id") - assert(res5.count == 3) - assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) - - // start, end are negative, step is positive - val res6 = spark.range(-8, -4, 2, 1).select("id") - assert(res6.count == 2) - assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - - val res7 = spark.range(-10, -9, -20, 1).select("id") - assert(res7.count == 0) - - val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") - assert(res8.count == 3) - assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - - val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") - assert(res9.count == 2) - assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) - - // only end provided as argument - val res10 = spark.range(10).select("id") - assert(res10.count == 10) - assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - - val res11 = spark.range(-1).select("id") - assert(res11.count == 0) - - // using the default slice number - val res12 = spark.range(3, 15, 3).select("id") - assert(res12.count == 4) - assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - } - test("SPARK-8621: support empty string column name") { val df = Seq(Tuple1(1)).toDF("").as("t") // We should allow empty string as column name http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala new file mode 100644 index 0000000..ddd7a03 --- /dev/null +++ 
http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6a190b9..e6338ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -979,59 +979,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
   }
 
-  test("SPARK-7150 range api") {
-    // numSlice is greater than length
-    val res1 = spark.range(0, 10, 1, 15).select("id")
-    assert(res1.count == 10)
-    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
-    val res2 = spark.range(3, 15, 3, 2).select("id")
-    assert(res2.count == 4)
-    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-
-    val res3 = spark.range(1, -2).select("id")
-    assert(res3.count == 0)
-
-    // start is positive, end is negative, step is negative
-    val res4 = spark.range(1, -2, -2, 6).select("id")
-    assert(res4.count == 2)
-    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
-
-    // start, end, step are negative
-    val res5 = spark.range(-3, -8, -2, 1).select("id")
-    assert(res5.count == 3)
-    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
-
-    // start, end are negative, step is positive
-    val res6 = spark.range(-8, -4, 2, 1).select("id")
-    assert(res6.count == 2)
-    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
-
-    val res7 = spark.range(-10, -9, -20, 1).select("id")
-    assert(res7.count == 0)
-
-    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
-    assert(res8.count == 3)
-    assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
-
-    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
-    assert(res9.count == 2)
-    assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
-
-    // only end provided as argument
-    val res10 = spark.range(10).select("id")
-    assert(res10.count == 10)
-    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
-    val res11 = spark.range(-1).select("id")
-    assert(res11.count == 0)
-
-    // using the default slice number
-    val res12 = spark.range(3, 15, 3).select("id")
-    assert(res12.count == 4)
-    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-  }
-
   test("SPARK-8621: support empty string column name") {
     val df = Seq(Tuple1(1)).toDF("").as("t")
     // We should allow empty string as column name


http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
new file mode 100644
index 0000000..ddd7a03
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.File
+
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.Utils
+
+class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
+
+  test("Range query input/output/generated metrics") {
+    val numRows = 150L
+    val numSelectedRows = 100L
+    val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
+      filter(x => x < numSelectedRows).toDF())
+
+    assert(res.recordsRead.sum === 0)
+    assert(res.shuffleRecordsRead.sum === 0)
+    assert(res.generatedRows === numRows :: Nil)
+    assert(res.outputRows === numSelectedRows :: numRows :: Nil)
+  }
+
+  test("Input/output/generated metrics with repartitioning") {
+    val numRows = 100L
+    val res = MetricsTestHelper.runAndGetMetrics(
+      spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
+
+    assert(res.recordsRead.sum === 0)
+    assert(res.shuffleRecordsRead.sum === numRows)
+    assert(res.generatedRows === numRows :: Nil)
+    assert(res.outputRows === 20 :: numRows :: Nil)
+  }
+
+  test("Input/output/generated metrics with more repartitioning") {
+    withTempDir { tempDir =>
+      val dir = new File(tempDir, "pqS").getCanonicalPath
+
+      spark.range(10).write.parquet(dir)
+      spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+      val res = MetricsTestHelper.runAndGetMetrics(
+        spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
+          .toDF()
+      )
+
+      assert(res.recordsRead.sum == 10)
+      assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
+      assert(res.generatedRows == 30 :: Nil)
+      assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
+    }
+  }
+}
+
+object MetricsTestHelper {
+  case class AggregatedMetricsResult(
+    recordsRead: List[Long],
+    shuffleRecordsRead: List[Long],
+    generatedRows: List[Long],
+    outputRows: List[Long])
+
+  private[this] def extractMetricValues(
+      df: DataFrame,
+      metricValues: Map[Long, String],
+      metricName: String): List[Long] = {
+    df.queryExecution.executedPlan.collect {
+      case plan if plan.metrics.contains(metricName) =>
+        metricValues(plan.metrics(metricName).id).toLong
+    }.toList.sorted
+  }
+
+  def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
+      AggregatedMetricsResult = {
+    val spark = df.sparkSession
+    val sparkContext = spark.sparkContext
+
+    var recordsRead = List[Long]()
+    var shuffleRecordsRead = List[Long]()
+    val listener = new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        if (taskEnd.taskMetrics != null) {
+          recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
+            recordsRead
+          shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
+            shuffleRecordsRead
+        }
+      }
+    }
+
+    val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
+
+    val prevUseWholeStageCodeGen =
+      spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
+    try {
+      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+      sparkContext.addSparkListener(listener)
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+    } finally {
+      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
+    }
+
+    val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
+    val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
+    val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
+    val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
+
+    AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
+  }
+}
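[Editor's note] A hypothetical usage sketch of the MetricsTestHelper added above (test scope only; assumes a SparkSession named `spark` with this patch applied). Range reads no input, so only the generated/output metrics move:

```scala
import org.apache.spark.sql.execution.MetricsTestHelper

// Run a simple range query and aggregate its metrics (codegen off by default).
val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, 50).toDF())

assert(res.recordsRead.sum == 0)        // Range is a leaf but reads nothing
assert(res.shuffleRecordsRead.sum == 0) // no exchange in this plan
assert(res.generatedRows == List(50L))  // reported by the new numGeneratedRows metric
assert(res.outputRows == List(50L))     // single operator, nothing filtered out
```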
http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 0396254..14fbe9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.MetricsTestHelper
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -915,4 +916,13 @@ class JDBCSuite extends SparkFunSuite
     }.getMessage
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }
+
+  test("Input/generated/output metrics on JDBC") {
+    val foobarCnt = spark.table("foobar").count()
+    val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
+    assert(res.recordsRead === foobarCnt :: Nil)
+    assert(res.shuffleRecordsRead.sum === 0)
+    assert(res.generatedRows.isEmpty)
+    assert(res.outputRows === foobarCnt :: Nil)
+  }
 }
http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index ec620c2..35c41b5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.sql.execution.MetricsTestHelper
 import org.apache.spark.sql.hive.test.TestHive
 
 /**
@@ -47,4 +48,22 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
   createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes")
 
   createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")
+
+  test("Test input/generated/output metrics") {
+    import TestHive._
+
+    val episodesCnt = sql("select * from episodes").count()
+    val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
+    assert(episodesRes.recordsRead === episodesCnt :: Nil)
+    assert(episodesRes.shuffleRecordsRead.sum === 0)
+    assert(episodesRes.generatedRows.isEmpty)
+    assert(episodesRes.outputRows === episodesCnt :: Nil)
+
+    val serdeinsCnt = sql("select * from serdeins").count()
+    val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
+    assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
+    assert(serdeinsRes.shuffleRecordsRead.sum === 0)
+    assert(serdeinsRes.generatedRows.isEmpty)
+    assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org