0lai0 commented on code in PR #3999:
URL: https://github.com/apache/datafusion-comet/pull/3999#discussion_r3114877914


##########
spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala:
##########
@@ -100,6 +105,73 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("native parquet write reports task-level output metrics") {
+    withParquetTable((0 until 5000).map(i => (i, (i + 1).toLong)), "tbl") {
+      withTempPath { dir =>
+        val outPath = new File(dir, "written").getAbsolutePath
+        val outputBytes = mutable.ArrayBuffer.empty[Long]
+        val outputRecords = mutable.ArrayBuffer.empty[Long]
+        val targetStageIds = mutable.HashSet.empty[Int]
+        val jobGroupId = s"native-write-metrics-${java.util.UUID.randomUUID().toString}"
+
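+        // Attribute task-end events to the write job via its job group:
+        // record stage IDs at job start, then match them at task end.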
+        val listener = new SparkListener {
+          override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+            val isTargetJob = Option(jobStart.properties)
+              .flatMap(props => Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)))
+              .contains(jobGroupId)
+            if (isTargetJob) {
+              targetStageIds.synchronized {
+                targetStageIds ++= jobStart.stageInfos.map(_.stageId)
+              }
+            }
+          }
+
+          override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+            val isTargetStage = targetStageIds.synchronized {
+              targetStageIds.contains(taskEnd.stageId)
+            }
+            if (isTargetStage) {
+              val om = taskEnd.taskMetrics.outputMetrics
+              if (om.bytesWritten > 0) {
+                outputBytes.synchronized {
+                  outputBytes += om.bytesWritten
+                  outputRecords += om.recordsWritten
+                }
+              }
+            }
+          }
+        }
+        spark.sparkContext.addSparkListener(listener)
+
+        try {
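+          // Drain events from earlier jobs so stale task-end updates are not counted.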
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+
+          withSQLConf(
+            CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            CometConf.getOperatorAllowIncompatConfigKey(
+              classOf[DataWritingCommandExec]) -> "true",
+            SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax") {
+            spark.sparkContext.setJobGroup(jobGroupId, "native parquet write output metrics")
+            try {
+              sql("SELECT * FROM tbl").write.parquet(outPath)
+            } finally {
+              spark.sparkContext.clearJobGroup()
+            }
+          }
+
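+          // Flush the write job's events before inspecting the collected metrics.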
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+
+          assert(outputBytes.nonEmpty, "No task reported outputMetrics.bytesWritten")

Review Comment:
   Updated. The test now asserts recordsWritten with exact equality and uses the existing suite convention for the bytes approximation (ratio within 0.7–1.3), consistent with the other input-metrics checks in CometTaskMetricsSuite.
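
   For reference, a minimal sketch of what those assertions could look like; the `expectedRecords` value and the on-disk size comparison are illustrative assumptions, not necessarily the exact code in the PR:

   ```scala
   // Hypothetical continuation of the test body above; names are illustrative.
   val expectedRecords = 5000L

   // recordsWritten is deterministic, so sum across tasks and compare exactly.
   assert(
     outputRecords.sum == expectedRecords,
     s"Expected $expectedRecords records, got ${outputRecords.sum}")

   // bytesWritten includes encoding and footer overhead, so compare against
   // the bytes actually on disk and allow the suite's 0.7-1.3 ratio.
   val bytesOnDisk = new File(outPath)
     .listFiles()
     .filter(_.getName.endsWith(".parquet"))
     .map(_.length())
     .sum
   val ratio = outputBytes.sum.toDouble / bytesOnDisk
   assert(
     ratio > 0.7 && ratio < 1.3,
     s"bytesWritten ratio $ratio outside [0.7, 1.3] (bytes on disk: $bytesOnDisk)")
   ```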



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

