comphead commented on code in PR #3999:
URL: https://github.com/apache/datafusion-comet/pull/3999#discussion_r3113673116


##########
spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala:
##########
@@ -100,6 +105,73 @@ class CometTaskMetricsSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
+  test("native parquet write reports task-level output metrics") {
+    withParquetTable((0 until 5000).map(i => (i, (i + 1).toLong)), "tbl") {
+      withTempPath { dir =>
+        val outPath = new File(dir, "written").getAbsolutePath
+        val outputBytes = mutable.ArrayBuffer.empty[Long]
+        val outputRecords = mutable.ArrayBuffer.empty[Long]
+        val targetStageIds = mutable.HashSet.empty[Int]
+        val jobGroupId = 
s"native-write-metrics-${java.util.UUID.randomUUID().toString}"
+
+        val listener = new SparkListener {
+          override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+            val isTargetJob = Option(jobStart.properties)
+              .flatMap(props => 
Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)))
+              .contains(jobGroupId)
+            if (isTargetJob) {
+              targetStageIds.synchronized {
+                targetStageIds ++= jobStart.stageInfos.map(_.stageId)
+              }
+            }
+          }
+
+          override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+            val isTargetStage = targetStageIds.synchronized {
+              targetStageIds.contains(taskEnd.stageId)
+            }
+            if (isTargetStage) {
+              val om = taskEnd.taskMetrics.outputMetrics
+              if (om.bytesWritten > 0) {
+                outputBytes.synchronized {
+                  outputBytes += om.bytesWritten
+                  outputRecords += om.recordsWritten
+                }
+              }
+            }
+          }
+        }
+        spark.sparkContext.addSparkListener(listener)
+
+        try {
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+
+          withSQLConf(
+            CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            CometConf.getOperatorAllowIncompatConfigKey(
+              classOf[DataWritingCommandExec]) -> "true",
+            SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax") {
+            spark.sparkContext.setJobGroup(jobGroupId, "native parquet write 
output metrics")
+            try {
+              sql("SELECT * FROM tbl").write.parquet(outPath)
+            } finally {
+              spark.sparkContext.clearJobGroup()
+            }
+          }
+
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+
+          assert(outputBytes.nonEmpty, "No task reported 
outputMetrics.bytesWritten")

Review Comment:
   Please check the test for input metrics; we definitely need to ensure 
the number of rows is the same. For bytes, we need some approximation.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to