mbutrovich commented on code in PR #3842:
URL: https://github.com/apache/datafusion-comet/pull/3842#discussion_r3017876444
##########
spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala:
##########
@@ -91,4 +94,66 @@ class CometTaskMetricsSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
}
+
test("native_datafusion scan reports task-level input metrics matching Spark") {
  withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") {
    // Baseline: run the query with Comet disabled to capture vanilla Spark's
    // task-level input metrics.
    val (sparkBytes, sparkRecords) =
      collectInputMetrics(CometConf.COMET_ENABLED.key -> "false")

    // Same query again, this time through the Comet native_datafusion scan.
    val (cometBytes, cometRecords) =
      collectInputMetrics(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

    // Row counts are deterministic, so they must agree exactly.
    assert(
      cometRecords == sparkRecords,
      s"recordsRead mismatch: comet=$cometRecords, spark=$sparkRecords")

    // Both readers consume the same Parquet file(s), but byte counts can
    // legitimately diverge a little (footer reads, page headers, buffering
    // granularity), so only require the totals to be within 20% of each other.
    assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
    assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")

    val ratio = cometBytes.toDouble / sparkBytes.toDouble
    val withinBallpark = ratio >= 0.8 && ratio <= 1.2
    assert(
      withinBallpark,
      s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
  }
}
+
+ /**
+ * Runs `SELECT * FROM tbl` with the given SQL config overrides and returns
the aggregated
+ * (bytesRead, recordsRead) across all tasks.
+ */
+ private def collectInputMetrics(confs: (String, String)*): (Long, Long) = {
+ val inputMetricsList = mutable.ArrayBuffer.empty[InputMetrics]
+
+ val listener = new SparkListener {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+ val im = taskEnd.taskMetrics.inputMetrics
+ inputMetricsList.synchronized {
+ inputMetricsList += im
+ }
+ }
+ }
+
+ spark.sparkContext.addSparkListener(listener)
+ try {
+ // Drain any earlier events
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+
+ withSQLConf(confs: _*) {
+ sql("SELECT * FROM tbl").collect()
Review Comment:
Consider adding a filter to the test query: it would expose the
discrepancy/incorrect metric values that occur when the scan is not the
first child node of the plan.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]