This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new fe49e4074 feat: CometNativeWriteExec support with native scan as a child (#2839)
fe49e4074 is described below

commit fe49e4074857ff398093516dbbeb551cbc5d3d07
Author: Matt Butrovich <[email protected]>
AuthorDate: Thu Dec 4 11:32:44 2025 -0500

    feat: CometNativeWriteExec support with native scan as a child (#2839)
---
 .../org/apache/comet/rules/CometExecRule.scala     |  11 +
 .../comet/parquet/CometParquetWriterSuite.scala    | 221 +++++++++++++--------
 2 files changed, 144 insertions(+), 88 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index a92082ae1..9152b9f78 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -536,6 +536,17 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
             firstNativeOp = true
           }
 
+          // CometNativeWriteExec is special: it has two separate plans:
+          // 1. A protobuf plan (nativeOp) describing the write operation
+          // 2. A Spark plan (child) that produces the data to write
+          // The serializedPlanOpt is a def that always returns Some(...) by serializing
+          // nativeOp on-demand, so it doesn't need convertBlock(). However, its child
+          // (e.g., CometNativeScanExec) may need its own serialization. Reset the flag
+          // so children can start their own native execution blocks.
+          if (op.isInstanceOf[CometNativeWriteExec]) {
+            firstNativeOp = true
+          }
+
           newPlan
         case op =>
           firstNativeOp = true
diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
index e4b8b5385..2ea697fd4 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
@@ -24,7 +24,7 @@ import java.io.File
 import scala.util.Random
 
 import org.apache.spark.sql.{CometTestBase, DataFrame}
-import org.apache.spark.sql.comet.CometNativeWriteExec
+import org.apache.spark.sql.comet.{CometNativeScanExec, CometNativeWriteExec}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.internal.SQLConf
@@ -34,122 +34,167 @@ import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, SchemaGenOpt
 
 class CometParquetWriterSuite extends CometTestBase {
 
-  test("basic parquet write") {
-    // no support for fully native scan as input yet
-    assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
+  private def createTestData(inputDir: File): String = {
+    val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
+    val schema = FuzzDataGenerator.generateSchema(
+      SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
+    val df = FuzzDataGenerator.generateDataFrame(
+      new Random(42),
+      spark,
+      schema,
+      1000,
+      DataGenOptions(generateNegativeZero = false))
+    withSQLConf(
+      CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
+      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
+      df.write.parquet(inputPath)
+    }
+    inputPath
+  }
+
+  private def writeWithCometNativeWriteExec(
+      inputPath: String,
+      outputPath: String): Option[QueryExecution] = {
+    val df = spark.read.parquet(inputPath)
+
+    // Use a listener to capture the execution plan during write
+    var capturedPlan: Option[QueryExecution] = None
+
+    val listener = new org.apache.spark.sql.util.QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        // Capture plans from write operations
+        if (funcName == "save" || funcName.contains("command")) {
+          capturedPlan = Some(qe)
+        }
+      }
+
+      override def onFailure(
+          funcName: String,
+          qe: QueryExecution,
+          exception: Exception): Unit = {}
+    }
+
+    spark.listenerManager.register(listener)
+
+    try {
+      // Perform native write
+      df.write.parquet(outputPath)
+
+      // Wait for listener to be called with timeout
+      val maxWaitTimeMs = 15000
+      val checkIntervalMs = 100
+      val maxIterations = maxWaitTimeMs / checkIntervalMs
+      var iterations = 0
+
+      while (capturedPlan.isEmpty && iterations < maxIterations) {
+        Thread.sleep(checkIntervalMs)
+        iterations += 1
+      }
+
+      // Verify that CometNativeWriteExec was used
+      assert(
+        capturedPlan.isDefined,
+        s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
+
+      capturedPlan.foreach { qe =>
+        val executedPlan = qe.executedPlan
+        val hasNativeWrite = executedPlan.exists {
+          case _: CometNativeWriteExec => true
+          case d: DataWritingCommandExec =>
+            d.child.exists {
+              case _: CometNativeWriteExec => true
+              case _ => false
+            }
+          case _ => false
+        }
+
+        assert(
+          hasNativeWrite,
+          s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+      }
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+    capturedPlan
+  }
+
+  private def verifyWrittenFile(outputPath: String): Unit = {
+    // Verify the data was written correctly
+    val resultDf = spark.read.parquet(outputPath)
+    assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
+
+    // Verify multiple part files were created
+    val outputDir = new File(outputPath)
+    val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
+    // With 1000 rows and default parallelism, we should get multiple partitions
+    assert(partFiles.length > 1, "Expected multiple part files to be created")
+
+    // read with and without Comet and compare
+    var sparkDf: DataFrame = null
+    var cometDf: DataFrame = null
+    withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
+      sparkDf = spark.read.parquet(outputPath)
+    }
+    withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
+      cometDf = spark.read.parquet(outputPath)
+    }
+    checkAnswer(sparkDf, cometDf)
+  }
 
+  test("basic parquet write") {
     withTempPath { dir =>
       val outputPath = new File(dir, "output.parquet").getAbsolutePath
 
       // Create test data and write it to a temp parquet file first
       withTempPath { inputDir =>
-        val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
-        val schema = FuzzDataGenerator.generateSchema(
-          SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
-        val df = FuzzDataGenerator.generateDataFrame(
-          new Random(42),
-          spark,
-          schema,
-          1000,
-          DataGenOptions(generateNegativeZero = false))
-        withSQLConf(
-          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
-          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
-          df.write.parquet(inputPath)
-        }
+        val inputPath = createTestData(inputDir)
 
         withSQLConf(
           CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
           SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
           CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
-          val df = spark.read.parquet(inputPath)
-
-          // Use a listener to capture the execution plan during write
-          var capturedPlan: Option[QueryExecution] = None
-
-          val listener = new org.apache.spark.sql.util.QueryExecutionListener {
-            override def onSuccess(
-                funcName: String,
-                qe: QueryExecution,
-                durationNs: Long): Unit = {
-              // Capture plans from write operations
-              if (funcName == "save" || funcName.contains("command")) {
-                capturedPlan = Some(qe)
-              }
-            }
 
-            override def onFailure(
-                funcName: String,
-                qe: QueryExecution,
-                exception: Exception): Unit = {}
-          }
+          writeWithCometNativeWriteExec(inputPath, outputPath)
 
-          spark.listenerManager.register(listener)
-
-          try {
-            // Perform native write
-            df.write.parquet(outputPath)
+          verifyWrittenFile(outputPath)
+        }
+      }
+    }
+  }
 
-            // Wait for listener to be called with timeout
-            val maxWaitTimeMs = 15000
-            val checkIntervalMs = 100
-            val maxIterations = maxWaitTimeMs / checkIntervalMs
-            var iterations = 0
+  test("basic parquet write with native scan child") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
 
-            while (capturedPlan.isEmpty && iterations < maxIterations) {
-              Thread.sleep(checkIntervalMs)
-              iterations += 1
-            }
+      // Create test data and write it to a temp parquet file first
+      withTempPath { inputDir =>
+        val inputPath = createTestData(inputDir)
 
-            // Verify that CometNativeWriteExec was used
-            assert(
-              capturedPlan.isDefined,
-              s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
+        withSQLConf(
+          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+          CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+          CometConf.COMET_EXEC_ENABLED.key -> "true") {
 
+          withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
+            val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
             capturedPlan.foreach { qe =>
               val executedPlan = qe.executedPlan
-              val hasNativeWrite = executedPlan.exists {
-                case _: CometNativeWriteExec => true
-                case d: DataWritingCommandExec =>
-                  d.child.exists {
-                    case _: CometNativeWriteExec => true
-                    case _ => false
-                  }
+              val hasNativeScan = executedPlan.exists {
+                case _: CometNativeScanExec => true
                 case _ => false
               }
 
               assert(
-                hasNativeWrite,
-                s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+                hasNativeScan,
+                s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
             }
-          } finally {
-            spark.listenerManager.unregister(listener)
-          }
 
-          // Verify the data was written correctly
-          val resultDf = spark.read.parquet(outputPath)
-          assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
-
-          // Verify multiple part files were created
-          val outputDir = new File(outputPath)
-          val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
-          // With 1000 rows and default parallelism, we should get multiple partitions
-          assert(partFiles.length > 1, "Expected multiple part files to be created")
-
-          // read with and without Comet and compare
-          var sparkDf: DataFrame = null
-          var cometDf: DataFrame = null
-          withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
-            sparkDf = spark.read.parquet(outputPath)
-          }
-          withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
-            cometDf = spark.read.parquet(outputPath)
+            verifyWrittenFile(outputPath)
           }
-          checkAnswer(sparkDf, cometDf)
         }
       }
     }
   }
-
 }
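
Not part of the commit itself, but as a rough usage sketch: the new test
exercises the native Parquet writer with a fully native scan child via the
configs visible in the diff above. The snippet below is an illustration
under assumptions, not project documentation; it assumes a live
SparkSession `spark` with Comet on the classpath, placeholder input/output
paths, and that CometConf resolves to org.apache.comet.CometConf.

    import org.apache.spark.sql.execution.command.DataWritingCommandExec
    import org.apache.comet.CometConf  // assumed import path

    // Enable Comet execution, the native Parquet writer (allow-incompat for
    // DataWritingCommandExec), and the fully native DataFusion scan.
    spark.conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
    spark.conf.set(CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key, "true")
    spark.conf.set(CometConf.COMET_NATIVE_SCAN_IMPL.key, "native_datafusion")
    spark.conf.set(
      CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]),
      "true")

    // With these settings, a plain read-then-write should plan a
    // CometNativeWriteExec whose child is a CometNativeScanExec.
    val df = spark.read.parquet("/tmp/input.parquet") // placeholder input path
    df.write.parquet("/tmp/output.parquet")           // placeholder output path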


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
