Github user hvanhovell commented on a diff in the pull request: https://github.com/apache/spark/pull/16960#discussion_r101783584 --- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala --- @@ -309,4 +314,94 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { assert(metricInfoDeser.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) } + test("range metrics") { + val res1 = InputOutputMetricsHelper.run( + spark.range(30).filter(x => x % 3 == 0).toDF() + ) + assert(res1 === (30L, 0L, 30L) :: Nil) + + val res2 = InputOutputMetricsHelper.run( + spark.range(150).repartition(4).filter(x => x < 10).toDF() + ) + assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil) + + withTempDir { tempDir => + val dir = new File(tempDir, "pqS").getCanonicalPath + + spark.range(10).write.parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("pqS") + + val res3 = InputOutputMetricsHelper.run( + spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF() + ) + assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil) + } + } +} + +object InputOutputMetricsHelper { + private class InputOutputMetricsListener extends SparkListener { + private case class MetricsResult( + var recordsRead: Long = 0L, + var shuffleRecordsRead: Long = 0L, + var sumMaxOutputRows: Long = 0L) + + private[this] var stageIdToMetricsResult = HashMap.empty[Int, MetricsResult] + + def reset(): Unit = { + stageIdToMetricsResult = HashMap.empty[Int, MetricsResult] + } + + /** + * Return a list of recorded metrics aggregated per stage. + * + * The list is sorted in the ascending order on the stageId. 
+ * For each recorded stage, the following tuple is returned: + * - sum of inputMetrics.recordsRead for all the tasks in the stage + * - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage + * - sum of the highest values of "number of output rows" metric for all the tasks in the stage + */ + def getResults(): List[(Long, Long, Long)] = { + stageIdToMetricsResult.keySet.toList.sorted.map({ stageId => + val res = stageIdToMetricsResult(stageId) + (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)}) + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, { MetricsResult() }) + + res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead + res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead + + var maxOutputRows = 0L + for (accum <- taskEnd.taskMetrics.externalAccums) { + val info = accum.toInfo(Some(accum.value), None) + if (info.name.toString.contains("number of output rows")) { + info.update match { + case Some(n: Number) => + if (n.longValue() > maxOutputRows) { + maxOutputRows = n.longValue() + } + case _ => // Ignore. + } + } + } + res.sumMaxOutputRows += maxOutputRows + } + } + + // Run df.collect() and return aggregated metrics for each stage. + def run(df: DataFrame): List[(Long, Long, Long)] = { + val spark = df.sparkSession + val sparkContext = spark.sparkContext + val listener = new InputOutputMetricsListener() --- End diff -- Use try...finally here
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and would like to enable it, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org