Repository: spark Updated Branches: refs/heads/branch-2.2 1fa3c86a7 -> f71aea6a0
[SPARK-20381][SQL] Add SQL metrics of numOutputRows for ObjectHashAggregateExec ## What changes were proposed in this pull request? ObjectHashAggregateExec is missing the numOutputRows metric; this patch adds it. ## How was this patch tested? Added unit tests for the new metric. Author: Yucai <yucai...@intel.com> Closes #17678 from yucai/objectAgg_numOutputRows. (cherry picked from commit 41439fd52dd263b9f7d92e608f027f193f461777) Signed-off-by: Xiao Li <gatorsm...@gmail.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f71aea6a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f71aea6a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f71aea6a Branch: refs/heads/branch-2.2 Commit: f71aea6a0be6eda24623d8563d971687ecd04caf Parents: 1fa3c86 Author: Yucai <yucai...@intel.com> Authored: Fri May 5 09:51:57 2017 -0700 Committer: Xiao Li <gatorsm...@gmail.com> Committed: Fri May 5 09:52:10 2017 -0700 ---------------------------------------------------------------------- .../aggregate/ObjectAggregationIterator.scala | 8 ++++++-- .../aggregate/ObjectHashAggregateExec.scala | 3 ++- .../sql/execution/metric/SQLMetricsSuite.scala | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f71aea6a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 3a7fcf1..6e47f9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering} import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator @@ -39,7 +40,8 @@ class ObjectAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, originalInputAttributes: Seq[Attribute], inputRows: Iterator[InternalRow], - fallbackCountThreshold: Int) + fallbackCountThreshold: Int, + numOutputRows: SQLMetric) extends AggregationIterator( groupingExpressions, originalInputAttributes, @@ -83,7 +85,9 @@ class ObjectAggregationIterator( override final def next(): UnsafeRow = { val entry = aggBufferIterator.next() - generateOutput(entry.groupingKey, entry.aggregationBuffer) + val res = generateOutput(entry.groupingKey, entry.aggregationBuffer) + numOutputRows += 1 + res } /** http://git-wip-us.apache.org/repos/asf/spark/blob/f71aea6a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 3fcb7ec..b53521b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -117,7 +117,8 @@ case class ObjectHashAggregateExec( newMutableProjection(expressions, inputSchema, 
subexpressionEliminationEnabled), child.output, iter, - fallbackCountThreshold) + fallbackCountThreshold, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) http://git-wip-us.apache.org/repos/asf/spark/blob/f71aea6a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 2ce7db6..e5442455 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -143,6 +143,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } + test("ObjectHashAggregate metrics") { + // Assume the execution plan is + // ... 
-> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> ObjectHashAggregate(nodeId = 0) + val df = testData2.groupBy().agg(collect_set('a)) // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("ObjectHashAggregate", Map("number of output rows" -> 2L)), + 0L -> ("ObjectHashAggregate", Map("number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).agg(collect_set('a)) + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("ObjectHashAggregate", Map("number of output rows" -> 4L)), + 0L -> ("ObjectHashAggregate", Map("number of output rows" -> 3L))) + ) + } + test("Sort metrics") { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org