This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 1b3354bde feat: Partially implement file commit protocol for native Parquet writes (#2828)
1b3354bde is described below

commit 1b3354bdef01cceb01500b966098da5fc92ec0a9
Author: Andy Grove <[email protected]>
AuthorDate: Tue Dec 2 11:36:05 2025 -0700

    feat: Partially implement file commit protocol for native Parquet writes (#2828)
---
 .../core/src/execution/operators/parquet_writer.rs | 100 ++++++---
 native/core/src/execution/planner.rs               |   7 +
 native/proto/src/proto/operator.proto              |   7 +
 .../serde/operator/CometDataWritingCommand.scala   |  27 ++-
 .../spark/sql/comet/CometNativeWriteExec.scala     | 225 +++++++++++++++++++--
 5 files changed, 320 insertions(+), 46 deletions(-)

diff --git a/native/core/src/execution/operators/parquet_writer.rs 
b/native/core/src/execution/operators/parquet_writer.rs
index 5536e30dc..57246abf7 100644
--- a/native/core/src/execution/operators/parquet_writer.rs
+++ b/native/core/src/execution/operators/parquet_writer.rs
@@ -25,7 +25,8 @@ use std::{
     sync::Arc,
 };
 
-use arrow::datatypes::SchemaRef;
+use arrow::datatypes::{Schema, SchemaRef};
+use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
 use datafusion::{
     error::{DataFusionError, Result},
@@ -34,6 +35,7 @@ use datafusion::{
     physical_plan::{
         execution_plan::{Boundedness, EmissionType},
         metrics::{ExecutionPlanMetricsSet, MetricsSet},
+        stream::RecordBatchStreamAdapter,
         DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, 
PlanProperties,
         SendableRecordBatchStream, Statistics,
     },
@@ -52,8 +54,14 @@ use crate::execution::shuffle::CompressionCodec;
 pub struct ParquetWriterExec {
     /// Input execution plan
     input: Arc<dyn ExecutionPlan>,
-    /// Output file path
+    /// Output file path (final destination)
     output_path: String,
+    /// Working directory for temporary files (used by FileCommitProtocol)
+    work_dir: String,
+    /// Job ID for tracking this write operation
+    job_id: Option<String>,
+    /// Task attempt ID for this specific task
+    task_attempt_id: Option<i32>,
     /// Compression codec
     compression: CompressionCodec,
     /// Partition ID (from Spark TaskContext)
@@ -68,9 +76,13 @@ pub struct ParquetWriterExec {
 
 impl ParquetWriterExec {
     /// Create a new ParquetWriterExec
+    #[allow(clippy::too_many_arguments)]
     pub fn try_new(
         input: Arc<dyn ExecutionPlan>,
         output_path: String,
+        work_dir: String,
+        job_id: Option<String>,
+        task_attempt_id: Option<i32>,
         compression: CompressionCodec,
         partition_id: i32,
         column_names: Vec<String>,
@@ -88,6 +100,9 @@ impl ParquetWriterExec {
         Ok(ParquetWriterExec {
             input,
             output_path,
+            work_dir,
+            job_id,
+            task_attempt_id,
             compression,
             partition_id,
             column_names,
@@ -144,7 +159,7 @@ impl ExecutionPlan for ParquetWriterExec {
     }
 
     fn schema(&self) -> SchemaRef {
-        self.input.schema()
+        Arc::new(Schema::empty())
     }
 
     fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
@@ -159,6 +174,9 @@ impl ExecutionPlan for ParquetWriterExec {
             1 => Ok(Arc::new(ParquetWriterExec::try_new(
                 Arc::clone(&children[0]),
                 self.output_path.clone(),
+                self.work_dir.clone(),
+                self.job_id.clone(),
+                self.task_attempt_id,
                 self.compression.clone(),
                 self.partition_id,
                 self.column_names.clone(),
@@ -174,34 +192,36 @@ impl ExecutionPlan for ParquetWriterExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
+        use datafusion::physical_plan::metrics::MetricBuilder;
+
+        // Create metrics for tracking write statistics
+        let files_written = 
MetricBuilder::new(&self.metrics).counter("files_written", partition);
+        let bytes_written = 
MetricBuilder::new(&self.metrics).counter("bytes_written", partition);
+        let rows_written = 
MetricBuilder::new(&self.metrics).counter("rows_written", partition);
+
         let input = self.input.execute(partition, context)?;
-        let input_schema = self.schema();
-        let output_path = self.output_path.clone();
+        let input_schema = self.input.schema();
+        let work_dir = self.work_dir.clone();
+        let task_attempt_id = self.task_attempt_id;
         let compression = self.compression_to_parquet()?;
         let column_names = self.column_names.clone();
 
         assert_eq!(input_schema.fields().len(), column_names.len());
 
-        // Create output schema with correct column names
-        let output_schema = if !column_names.is_empty() {
-            // Replace the generic column names (col_0, col_1, etc.) with the 
actual names
-            let fields: Vec<_> = input_schema
-                .fields()
-                .iter()
-                .enumerate()
-                .map(|(i, field)| 
Arc::new(field.as_ref().clone().with_name(&column_names[i])))
-                .collect();
-            Arc::new(arrow::datatypes::Schema::new(fields))
-        } else {
-            // No column names provided, use input schema as-is
-            Arc::clone(&input_schema)
-        };
+        // Replace the generic column names (col_0, col_1, etc.) with the 
actual names
+        let fields: Vec<_> = input_schema
+            .fields()
+            .iter()
+            .enumerate()
+            .map(|(i, field)| 
Arc::new(field.as_ref().clone().with_name(&column_names[i])))
+            .collect();
+        let output_schema = Arc::new(arrow::datatypes::Schema::new(fields));
 
         // Strip file:// or file: prefix if present
-        let local_path = output_path
+        let local_path = work_dir
             .strip_prefix("file://")
-            .or_else(|| output_path.strip_prefix("file:"))
-            .unwrap_or(&output_path)
+            .or_else(|| work_dir.strip_prefix("file:"))
+            .unwrap_or(&work_dir)
             .to_string();
 
         // Create output directory
@@ -213,7 +233,15 @@ impl ExecutionPlan for ParquetWriterExec {
         })?;
 
         // Generate part file name for this partition
-        let part_file = format!("{}/part-{:05}.parquet", local_path, 
self.partition_id);
+        // If using FileCommitProtocol (work_dir is set), include 
task_attempt_id in the filename
+        let part_file = if let Some(attempt_id) = task_attempt_id {
+            format!(
+                "{}/part-{:05}-{:05}.parquet",
+                local_path, self.partition_id, attempt_id
+            )
+        } else {
+            format!("{}/part-{:05}.parquet", local_path, self.partition_id)
+        };
 
         // Create the Parquet file
         let file = File::create(&part_file).map_err(|e| {
@@ -237,13 +265,16 @@ impl ExecutionPlan for ParquetWriterExec {
         // Write batches
         let write_task = async move {
             let mut stream = input;
+            let mut total_rows = 0i64;
 
             while let Some(batch_result) = stream.try_next().await.transpose() 
{
                 let batch = batch_result?;
 
+                // Track row count
+                total_rows += batch.num_rows() as i64;
+
                 // Rename columns in the batch to match output schema
                 let renamed_batch = if !column_names.is_empty() {
-                    use arrow::record_batch::RecordBatch;
                     RecordBatch::try_new(Arc::clone(&schema_for_write), 
batch.columns().to_vec())
                         .map_err(|e| {
                             DataFusionError::Execution(format!(
@@ -264,14 +295,29 @@ impl ExecutionPlan for ParquetWriterExec {
                 DataFusionError::Execution(format!("Failed to close writer: 
{}", e))
             })?;
 
+            // Get file size
+            let file_size = std::fs::metadata(&part_file)
+                .map(|m| m.len() as i64)
+                .unwrap_or(0);
+
+            // Update metrics with write statistics
+            files_written.add(1);
+            bytes_written.add(file_size as usize);
+            rows_written.add(total_rows as usize);
+
+            // Log metadata for debugging
+            eprintln!(
+                "Wrote Parquet file: path={}, size={}, rows={}",
+                part_file, file_size, total_rows
+            );
+
             // Return empty stream to indicate completion
             Ok::<_, DataFusionError>(futures::stream::empty())
         };
 
-        // Execute the write task and convert to a stream
-        use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+        // Execute the write task and create a stream that does not return any 
batches
         Ok(Box::pin(RecordBatchStreamAdapter::new(
-            output_schema,
+            self.schema(),
             futures::stream::once(write_task).try_flatten(),
         )))
     }
diff --git a/native/core/src/execution/planner.rs 
b/native/core/src/execution/planner.rs
index b0746a6f8..ccbd0b250 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1466,6 +1466,13 @@ impl PhysicalPlanner {
                 let parquet_writer = Arc::new(ParquetWriterExec::try_new(
                     Arc::clone(&child.native_plan),
                     writer.output_path.clone(),
+                    writer
+                        .work_dir
+                        .as_ref()
+                        .expect("work_dir is provided")
+                        .clone(),
+                    writer.job_id.clone(),
+                    writer.task_attempt_id,
                     codec,
                     self.partition,
                     writer.column_names.clone(),
diff --git a/native/proto/src/proto/operator.proto 
b/native/proto/src/proto/operator.proto
index a95832709..f09695b7c 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -241,6 +241,13 @@ message ParquetWriter {
   string output_path = 1;
   CompressionCodec compression = 2;
   repeated string column_names = 4;
+  // Working directory for temporary files (used by FileCommitProtocol)
+  // If not set, files are written directly to output_path
+  optional string work_dir = 5;
+  // Job ID for tracking this write operation
+  optional string job_id = 6;
+  // Task attempt ID for this specific task
+  optional int32 task_attempt_id = 7;
 }
 
 enum AggregateMode {
diff --git 
a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
 
b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
index 46d01c887..7fdf05521 100644
--- 
a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
+++ 
b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
@@ -23,6 +23,7 @@ import java.util.Locale
 
 import scala.jdk.CollectionConverters._
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.comet.{CometNativeExec, CometNativeWriteExec}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import 
org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, 
WriteFilesExec}
@@ -129,6 +130,8 @@ object CometDataWritingCommand extends 
CometOperatorSerde[DataWritingCommandExec
         .setOutputPath(outputPath)
         .setCompression(codec)
         .addAllColumnNames(cmd.query.output.map(_.name).asJava)
+        // Note: work_dir, job_id, and task_attempt_id will be set at 
execution time
+        // in CometNativeWriteExec, as they depend on the Spark task context
         .build()
 
       val writerOperator = Operator
@@ -163,7 +166,29 @@ object CometDataWritingCommand extends 
CometOperatorSerde[DataWritingCommandExec
         other
     }
 
-    CometNativeWriteExec(nativeOp, childPlan, outputPath)
+    // Create FileCommitProtocol for atomic writes
+    val jobId = java.util.UUID.randomUUID().toString
+    val committer =
+      try {
+        // Use Spark's SQLHadoopMapReduceCommitProtocol
+        val committerClass =
+          
classOf[org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol]
+        val constructor =
+          committerClass.getConstructor(classOf[String], classOf[String], 
classOf[Boolean])
+        Some(
+          constructor
+            .newInstance(
+              jobId,
+              outputPath,
+              java.lang.Boolean.FALSE // dynamicPartitionOverwrite = false for 
now
+            )
+            .asInstanceOf[org.apache.spark.internal.io.FileCommitProtocol])
+      } catch {
+        case e: Exception =>
+          throw new SparkException(s"Could not instantiate FileCommitProtocol: 
${e.getMessage}")
+      }
+
+    CometNativeWriteExec(nativeOp, childPlan, outputPath, committer, jobId)
   }
 
   private def parseCompressionCodec(cmd: InsertIntoHadoopFsRelationCommand) = {
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
index 2617e8c60..f153a691e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
@@ -19,54 +19,122 @@
 
 package org.apache.spark.sql.comet
 
+import java.io.ByteArrayOutputStream
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext, TaskAttemptID, 
TaskID, TaskType}
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.Utils
 
 import org.apache.comet.CometExecIterator
 import org.apache.comet.serde.OperatorOuterClass.Operator
 
 /**
- * Comet physical operator for native Parquet write operations.
+ * Comet physical operator for native Parquet write operations with 
FileCommitProtocol support.
+ *
+ * This operator writes data to Parquet files using the native Comet engine. 
It integrates with
+ * Spark's FileCommitProtocol to provide atomic writes with proper staging and 
commit semantics.
  *
- * This operator writes data to Parquet files using the native Comet engine. 
It wraps the child
- * operator and adds a ParquetWriter operator on top.
+ * The implementation includes support for Spark's file commit protocol 
through work_dir, job_id,
+ * and task_attempt_id parameters that can be set in the operator. When 
work_dir is set, files are
+ * written to a temporary location that can be atomically committed later.
  *
  * @param nativeOp
- *   The native operator representing the write operation
+ *   The native operator representing the write operation (template, will be 
modified per task)
  * @param child
  *   The child operator providing the data to write
  * @param outputPath
  *   The path where the Parquet file will be written
+ * @param committer
+ *   FileCommitProtocol for atomic writes. If None, files are written directly.
+ * @param jobTrackerID
+ *   Unique identifier for this write job
  */
-case class CometNativeWriteExec(nativeOp: Operator, child: SparkPlan, 
outputPath: String)
+case class CometNativeWriteExec(
+    nativeOp: Operator,
+    child: SparkPlan,
+    outputPath: String,
+    committer: Option[FileCommitProtocol] = None,
+    jobTrackerID: String = Utils.createTempDir().getName)
     extends CometNativeExec
     with UnaryExecNode {
 
   override def originalPlan: SparkPlan = child
 
+  // Accumulator to collect TaskCommitMessages from all tasks
+  // Must be eagerly initialized on driver, not lazy
+  @transient private val taskCommitMessagesAccum =
+    
sparkContext.collectionAccumulator[FileCommitProtocol.TaskCommitMessage]("taskCommitMessages")
+
   override def serializedPlanOpt: SerializedPlan = {
-    val outputStream = new java.io.ByteArrayOutputStream()
+    val outputStream = new ByteArrayOutputStream()
     nativeOp.writeTo(outputStream)
     outputStream.close()
     SerializedPlan(Some(outputStream.toByteArray))
   }
 
-  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+  override def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
   override def nodeName: String = "CometNativeWrite"
 
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "files_written" -> SQLMetrics.createMetric(sparkContext, "number of 
written data files"),
+    "bytes_written" -> SQLMetrics.createSizeMetric(sparkContext, "written 
data"),
+    "rows_written" -> SQLMetrics.createMetric(sparkContext, "number of written 
rows"))
+
   override def doExecute(): RDD[InternalRow] = {
-    // Execute the native write
+    // Setup job if committer is present
+    committer.foreach { c =>
+      val jobContext = createJobContext()
+      c.setupJob(jobContext)
+    }
+
+    // Execute the native write with commit protocol
     val resultRDD = doExecuteColumnar()
-    // Convert to empty InternalRow RDD (write operations typically return 
empty results)
-    resultRDD.mapPartitions { iter =>
-      // Consume all batches (they should be empty)
-      iter.foreach(_.close())
-      Iterator.empty
+
+    // Force execution by consuming all batches
+    resultRDD
+      .mapPartitions { iter =>
+        iter.foreach(_.close())
+        Iterator.empty
+      }
+      .count()
+
+    // Extract write statistics from metrics
+    val filesWritten = metrics("files_written").value
+    val bytesWritten = metrics("bytes_written").value
+    val rowsWritten = metrics("rows_written").value
+
+    // Collect TaskCommitMessages from accumulator
+    val commitMessages = taskCommitMessagesAccum.value.asScala.toSeq
+
+    // Commit job with collected TaskCommitMessages
+    committer.foreach { c =>
+      val jobContext = createJobContext()
+      try {
+        c.commitJob(jobContext, commitMessages)
+        logInfo(
+          s"Successfully committed write job to $outputPath: " +
+            s"$filesWritten files, $bytesWritten bytes, $rowsWritten rows")
+      } catch {
+        case e: Exception =>
+          logError("Failed to commit job, aborting", e)
+          c.abortJob(jobContext)
+          throw e
+      }
     }
+
+    // Return empty RDD as write operations don't return data
+    sparkContext.emptyRDD[InternalRow]
   }
 
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
@@ -86,27 +154,148 @@ case class CometNativeWriteExec(nativeOp: Operator, 
child: SparkPlan, outputPath
     // Capture metadata before the transformation
     val numPartitions = childRDD.getNumPartitions
     val numOutputCols = child.output.length
+    val capturedCommitter = committer
+    val capturedJobTrackerID = jobTrackerID
+    val capturedNativeOp = nativeOp
+    val capturedAccumulator = taskCommitMessagesAccum // Capture accumulator 
for use in tasks
 
-    // Execute native write operation
+    // Execute native write operation with task-level commit protocol
     childRDD.mapPartitionsInternal { iter =>
+      val partitionId = org.apache.spark.TaskContext.getPartitionId()
+      val taskAttemptId = org.apache.spark.TaskContext.get().taskAttemptId()
+
+      // Setup task-level commit protocol if provided
+      val (workDir, taskContext, commitMsg) = capturedCommitter
+        .map { committer =>
+          val taskContext =
+            createTaskContext(capturedJobTrackerID, partitionId, 
taskAttemptId.toInt)
+
+          // Setup task - this creates the temporary working directory
+          committer.setupTask(taskContext)
+
+          // Get the work directory for temp files
+          val workPath = committer.newTaskTempFile(taskContext, None, "")
+          val workDir = new Path(workPath).getParent.toString
+
+          (Some(workDir), Some((committer, taskContext)), null)
+        }
+        .getOrElse((None, None, null))
+
+      // Modify the native operator to include task-specific parameters
+      val modifiedNativeOp = if (workDir.isDefined) {
+        val parquetWriter = capturedNativeOp.getParquetWriter.toBuilder
+          .setWorkDir(workDir.get)
+          .setJobId(capturedJobTrackerID)
+          .setTaskAttemptId(taskAttemptId.toInt)
+          .build()
+
+        capturedNativeOp.toBuilder.setParquetWriter(parquetWriter).build()
+      } else {
+        capturedNativeOp
+      }
+
       val nativeMetrics = CometMetricNode.fromCometPlan(this)
 
-      val outputStream = new java.io.ByteArrayOutputStream()
-      nativeOp.writeTo(outputStream)
+      val outputStream = new ByteArrayOutputStream()
+      modifiedNativeOp.writeTo(outputStream)
       outputStream.close()
       val planBytes = outputStream.toByteArray
 
-      new CometExecIterator(
+      val execIterator = new CometExecIterator(
         CometExec.newIterId,
         Seq(iter),
         numOutputCols,
         planBytes,
         nativeMetrics,
         numPartitions,
-        org.apache.spark.TaskContext.getPartitionId(),
+        partitionId,
         None,
         Seq.empty)
 
+      // Wrap the iterator to handle task commit/abort and capture 
TaskCommitMessage
+      new Iterator[ColumnarBatch] {
+        private var completed = false
+        private var thrownException: Option[Throwable] = None
+
+        override def hasNext: Boolean = {
+          val result =
+            try {
+              execIterator.hasNext
+            } catch {
+              case e: Throwable =>
+                thrownException = Some(e)
+                handleTaskEnd()
+                throw e
+            }
+
+          if (!result && !completed) {
+            handleTaskEnd()
+          }
+
+          result
+        }
+
+        override def next(): ColumnarBatch = {
+          try {
+            execIterator.next()
+          } catch {
+            case e: Throwable =>
+              thrownException = Some(e)
+              handleTaskEnd()
+              throw e
+          }
+        }
+
+        private def handleTaskEnd(): Unit = {
+          if (!completed) {
+            completed = true
+
+            // Handle commit or abort based on whether an exception was thrown
+            taskContext.foreach { case (committer, ctx) =>
+              try {
+                if (thrownException.isEmpty) {
+                  // Commit the task and add message to accumulator
+                  val message = committer.commitTask(ctx)
+                  capturedAccumulator.add(message)
+                  logInfo(s"Task ${ctx.getTaskAttemptID} committed 
successfully")
+                } else {
+                  // Abort the task
+                  committer.abortTask(ctx)
+                  val exMsg = thrownException.get.getMessage
+                  logWarning(s"Task ${ctx.getTaskAttemptID} aborted due to 
exception: $exMsg")
+                }
+              } catch {
+                case e: Exception =>
+                  // Log the commit/abort exception but don't mask the 
original exception
+                  logError(s"Error during task commit/abort: ${e.getMessage}", 
e)
+                  if (thrownException.isEmpty) {
+                    // If no original exception, propagate the commit/abort 
exception
+                    throw e
+                  }
+              }
+            }
+          }
+        }
+      }
     }
   }
+
+  /** Create a JobContext for the write job */
+  private def createJobContext(): Job = {
+    val job = Job.getInstance()
+    job.setJobID(new org.apache.hadoop.mapreduce.JobID(jobTrackerID, 0))
+    job
+  }
+
+  /** Create a TaskAttemptContext for a specific task */
+  private def createTaskContext(
+      jobId: String,
+      partitionId: Int,
+      attemptNumber: Int): TaskAttemptContext = {
+    val job = Job.getInstance()
+    val taskAttemptID = new TaskAttemptID(
+      new TaskID(new org.apache.hadoop.mapreduce.JobID(jobId, 0), 
TaskType.REDUCE, partitionId),
+      attemptNumber)
+    new TaskAttemptContextImpl(job.getConfiguration, taskAttemptID)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to