This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 0ae651554 Avoid duplicated writer nodes when AQE enabled (#2982)
0ae651554 is described below

commit 0ae65155432b95f00c4ccbd58609eb271c591ba2
Author: Oleks V <[email protected]>
AuthorDate: Tue Dec 23 18:28:14 2025 -0800

    Avoid duplicated writer nodes when AQE enabled (#2982)
    
    * feat: Avoid duplicated write nodes for AQE execution
---
 .../org/apache/comet/rules/CometExecRule.scala     |  9 ++++
 .../comet/parquet/CometParquetWriterSuite.scala    | 55 +++++++++++++++++-----
 2 files changed, 52 insertions(+), 12 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala 
b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index ed48e36f0..bb4ce879d 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, 
ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, 
ExecutedCommandExec}
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -197,6 +198,14 @@ case class CometExecRule(session: SparkSession) extends 
Rule[SparkPlan] {
       case op if shouldApplySparkToColumnar(conf, op) =>
         convertToComet(op, CometSparkToColumnarExec).getOrElse(op)
 
+      // AQE reoptimization looks for `DataWritingCommandExec` or 
`WriteFilesExec`;
+      // if there is none, it reinserts write nodes, and since Comet remaps 
those nodes
+      // to Comet counterparts, the write nodes are added twice to the plan.
+      // Check whether AQE inserted another write command on top of an 
existing write command
+      case _ @DataWritingCommandExec(_, w: WriteFilesExec)
+          if w.child.isInstanceOf[CometNativeWriteExec] =>
+        w.child
+
       case op: DataWritingCommandExec =>
         convertToComet(op, CometDataWritingCommand).getOrElse(op)
 
diff --git 
a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala 
b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
index 2ea697fd4..3ae7f949a 100644
--- 
a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
+++ 
b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
@@ -54,7 +54,8 @@ class CometParquetWriterSuite extends CometTestBase {
 
   private def writeWithCometNativeWriteExec(
       inputPath: String,
-      outputPath: String): Option[QueryExecution] = {
+      outputPath: String,
+      num_partitions: Option[Int] = None): Option[QueryExecution] = {
     val df = spark.read.parquet(inputPath)
 
     // Use a listener to capture the execution plan during write
@@ -77,8 +78,8 @@ class CometParquetWriterSuite extends CometTestBase {
     spark.listenerManager.register(listener)
 
     try {
-      // Perform native write
-      df.write.parquet(outputPath)
+      // Perform native write with optional partitioning
+      num_partitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)
 
       // Wait for listener to be called with timeout
       val maxWaitTimeMs = 15000
@@ -97,20 +98,25 @@ class CometParquetWriterSuite extends CometTestBase {
         s"Listener was not called within ${maxWaitTimeMs}ms - no execution 
plan captured")
 
       capturedPlan.foreach { qe =>
-        val executedPlan = qe.executedPlan
-        val hasNativeWrite = executedPlan.exists {
-          case _: CometNativeWriteExec => true
+        val executedPlan = stripAQEPlan(qe.executedPlan)
+
+        // Count CometNativeWriteExec instances in the plan
+        var nativeWriteCount = 0
+        executedPlan.foreach {
+          case _: CometNativeWriteExec =>
+            nativeWriteCount += 1
           case d: DataWritingCommandExec =>
-            d.child.exists {
-              case _: CometNativeWriteExec => true
-              case _ => false
+            d.child.foreach {
+              case _: CometNativeWriteExec =>
+                nativeWriteCount += 1
+              case _ =>
             }
-          case _ => false
+          case _ =>
         }
 
         assert(
-          hasNativeWrite,
-          s"Expected CometNativeWriteExec in the plan, but 
got:\n${executedPlan.treeString}")
+          nativeWriteCount == 1,
+          s"Expected exactly one CometNativeWriteExec in the plan, but found 
$nativeWriteCount:\n${executedPlan.treeString}")
       }
     } finally {
       spark.listenerManager.unregister(listener)
@@ -197,4 +203,29 @@ class CometParquetWriterSuite extends CometTestBase {
       }
     }
   }
+
+  test("basic parquet write with repartition") {
+    withTempPath { dir =>
+      // Create test data and write it to a temp parquet file first
+      withTempPath { inputDir =>
+        val inputPath = createTestData(inputDir)
+        Seq(true, false).foreach(adaptive => {
+          // Create a new output path for each AQE value
+          val outputPath = new File(dir, 
s"output_aqe_$adaptive.parquet").getAbsolutePath
+
+          withSQLConf(
+            CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+            "spark.sql.adaptive.enabled" -> adaptive.toString,
+            SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+            CometConf.getOperatorAllowIncompatConfigKey(
+              classOf[DataWritingCommandExec]) -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+            writeWithCometNativeWriteExec(inputPath, outputPath, Some(10))
+            verifyWrittenFile(outputPath)
+          }
+        })
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to