This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 1b3354bde feat: Partially implement file commit protocol for native Parquet writes (#2828)
1b3354bde is described below
commit 1b3354bdef01cceb01500b966098da5fc92ec0a9
Author: Andy Grove <[email protected]>
AuthorDate: Tue Dec 2 11:36:05 2025 -0700
feat: Partially implement file commit protocol for native Parquet writes (#2828)
---
.../core/src/execution/operators/parquet_writer.rs | 100 ++++++---
native/core/src/execution/planner.rs | 7 +
native/proto/src/proto/operator.proto | 7 +
.../serde/operator/CometDataWritingCommand.scala | 27 ++-
.../spark/sql/comet/CometNativeWriteExec.scala | 225 +++++++++++++++++++--
5 files changed, 320 insertions(+), 46 deletions(-)
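For readers unfamiliar with Spark's FileCommitProtocol: the patch below stages each task's Parquet output in a committer-managed work directory and only promotes it to the final output path when the job commits. The following sketch is not part of this commit; it is a minimal, single-process illustration of that setupJob/setupTask/newTaskTempFile/commitTask/commitJob lifecycle, and it assumes the use of Spark's HadoopMapReduceCommitProtocol directly, whereas the patch instantiates SQLHadoopMapReduceCommitProtocol reflectively and drives the same sequence from CometNativeWriteExec. The object name and output path are hypothetical.

import java.util.UUID

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{Job, JobID, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol

object CommitProtocolLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val jobId = UUID.randomUUID().toString
    val outputPath = "/tmp/commit-protocol-sketch" // hypothetical final destination

    // Driver side: create the committer and set up the job (creates the staging area).
    val committer = new HadoopMapReduceCommitProtocol(jobId, outputPath, false)
    val job = Job.getInstance()
    job.setJobID(new JobID(jobId, 0))
    FileOutputFormat.setOutputPath(job, new Path(outputPath))
    committer.setupJob(job)

    // Executor side (one task): set up the task and ask the committer for a temp file path.
    val taskAttemptId = new TaskAttemptID(new TaskID(new JobID(jobId, 0), TaskType.REDUCE, 0), 0)
    val taskCtx = new TaskAttemptContextImpl(job.getConfiguration, taskAttemptId)
    committer.setupTask(taskCtx)
    val tempFile = committer.newTaskTempFile(taskCtx, None, ".parquet")
    // ... a writer (in the patch, the native Parquet writer) produces its output at tempFile ...

    // Executor side: commit the task; the returned message travels back to the driver.
    val commitMessage = committer.commitTask(taskCtx)

    // Driver side: commit the job with all task messages, promoting staged files to outputPath.
    committer.commitJob(job, Seq(commitMessage))
  }
}

In the patch, the executor-side half of this sequence happens inside mapPartitionsInternal in CometNativeWriteExec, and the TaskCommitMessages flow back to the driver through an accumulator rather than a direct return value.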
diff --git a/native/core/src/execution/operators/parquet_writer.rs b/native/core/src/execution/operators/parquet_writer.rs
index 5536e30dc..57246abf7 100644
--- a/native/core/src/execution/operators/parquet_writer.rs
+++ b/native/core/src/execution/operators/parquet_writer.rs
@@ -25,7 +25,8 @@ use std::{
sync::Arc,
};
-use arrow::datatypes::SchemaRef;
+use arrow::datatypes::{Schema, SchemaRef};
+use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::{
error::{DataFusionError, Result},
@@ -34,6 +35,7 @@ use datafusion::{
physical_plan::{
execution_plan::{Boundedness, EmissionType},
metrics::{ExecutionPlanMetricsSet, MetricsSet},
+ stream::RecordBatchStreamAdapter,
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties,
PlanProperties,
SendableRecordBatchStream, Statistics,
},
@@ -52,8 +54,14 @@ use crate::execution::shuffle::CompressionCodec;
pub struct ParquetWriterExec {
/// Input execution plan
input: Arc<dyn ExecutionPlan>,
- /// Output file path
+ /// Output file path (final destination)
output_path: String,
+ /// Working directory for temporary files (used by FileCommitProtocol)
+ work_dir: String,
+ /// Job ID for tracking this write operation
+ job_id: Option<String>,
+ /// Task attempt ID for this specific task
+ task_attempt_id: Option<i32>,
/// Compression codec
compression: CompressionCodec,
/// Partition ID (from Spark TaskContext)
@@ -68,9 +76,13 @@ pub struct ParquetWriterExec {
impl ParquetWriterExec {
/// Create a new ParquetWriterExec
+ #[allow(clippy::too_many_arguments)]
pub fn try_new(
input: Arc<dyn ExecutionPlan>,
output_path: String,
+ work_dir: String,
+ job_id: Option<String>,
+ task_attempt_id: Option<i32>,
compression: CompressionCodec,
partition_id: i32,
column_names: Vec<String>,
@@ -88,6 +100,9 @@ impl ParquetWriterExec {
Ok(ParquetWriterExec {
input,
output_path,
+ work_dir,
+ job_id,
+ task_attempt_id,
compression,
partition_id,
column_names,
@@ -144,7 +159,7 @@ impl ExecutionPlan for ParquetWriterExec {
}
fn schema(&self) -> SchemaRef {
- self.input.schema()
+ Arc::new(Schema::empty())
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
@@ -159,6 +174,9 @@ impl ExecutionPlan for ParquetWriterExec {
1 => Ok(Arc::new(ParquetWriterExec::try_new(
Arc::clone(&children[0]),
self.output_path.clone(),
+ self.work_dir.clone(),
+ self.job_id.clone(),
+ self.task_attempt_id,
self.compression.clone(),
self.partition_id,
self.column_names.clone(),
@@ -174,34 +192,36 @@ impl ExecutionPlan for ParquetWriterExec {
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
+ use datafusion::physical_plan::metrics::MetricBuilder;
+
+ // Create metrics for tracking write statistics
+ let files_written = MetricBuilder::new(&self.metrics).counter("files_written", partition);
+ let bytes_written = MetricBuilder::new(&self.metrics).counter("bytes_written", partition);
+ let rows_written = MetricBuilder::new(&self.metrics).counter("rows_written", partition);
+
let input = self.input.execute(partition, context)?;
- let input_schema = self.schema();
- let output_path = self.output_path.clone();
+ let input_schema = self.input.schema();
+ let work_dir = self.work_dir.clone();
+ let task_attempt_id = self.task_attempt_id;
let compression = self.compression_to_parquet()?;
let column_names = self.column_names.clone();
assert_eq!(input_schema.fields().len(), column_names.len());
- // Create output schema with correct column names
- let output_schema = if !column_names.is_empty() {
- // Replace the generic column names (col_0, col_1, etc.) with the actual names
- let fields: Vec<_> = input_schema
- .fields()
- .iter()
- .enumerate()
- .map(|(i, field)| Arc::new(field.as_ref().clone().with_name(&column_names[i])))
- .collect();
- Arc::new(arrow::datatypes::Schema::new(fields))
- } else {
- // No column names provided, use input schema as-is
- Arc::clone(&input_schema)
- };
+ // Replace the generic column names (col_0, col_1, etc.) with the actual names
+ let fields: Vec<_> = input_schema
+ .fields()
+ .iter()
+ .enumerate()
+ .map(|(i, field)| Arc::new(field.as_ref().clone().with_name(&column_names[i])))
+ .collect();
+ let output_schema = Arc::new(arrow::datatypes::Schema::new(fields));
// Strip file:// or file: prefix if present
- let local_path = output_path
+ let local_path = work_dir
.strip_prefix("file://")
- .or_else(|| output_path.strip_prefix("file:"))
- .unwrap_or(&output_path)
+ .or_else(|| work_dir.strip_prefix("file:"))
+ .unwrap_or(&work_dir)
.to_string();
// Create output directory
@@ -213,7 +233,15 @@ impl ExecutionPlan for ParquetWriterExec {
})?;
// Generate part file name for this partition
- let part_file = format!("{}/part-{:05}.parquet", local_path, self.partition_id);
+ // If using FileCommitProtocol (work_dir is set), include task_attempt_id in the filename
+ let part_file = if let Some(attempt_id) = task_attempt_id {
+ format!(
+ "{}/part-{:05}-{:05}.parquet",
+ local_path, self.partition_id, attempt_id
+ )
+ } else {
+ format!("{}/part-{:05}.parquet", local_path, self.partition_id)
+ };
// Create the Parquet file
let file = File::create(&part_file).map_err(|e| {
@@ -237,13 +265,16 @@ impl ExecutionPlan for ParquetWriterExec {
// Write batches
let write_task = async move {
let mut stream = input;
+ let mut total_rows = 0i64;
while let Some(batch_result) = stream.try_next().await.transpose() {
let batch = batch_result?;
+ // Track row count
+ total_rows += batch.num_rows() as i64;
+
// Rename columns in the batch to match output schema
let renamed_batch = if !column_names.is_empty() {
- use arrow::record_batch::RecordBatch;
RecordBatch::try_new(Arc::clone(&schema_for_write), batch.columns().to_vec())
.map_err(|e| {
DataFusionError::Execution(format!(
@@ -264,14 +295,29 @@ impl ExecutionPlan for ParquetWriterExec {
DataFusionError::Execution(format!("Failed to close writer: {}", e))
})?;
+ // Get file size
+ let file_size = std::fs::metadata(&part_file)
+ .map(|m| m.len() as i64)
+ .unwrap_or(0);
+
+ // Update metrics with write statistics
+ files_written.add(1);
+ bytes_written.add(file_size as usize);
+ rows_written.add(total_rows as usize);
+
+ // Log metadata for debugging
+ eprintln!(
+ "Wrote Parquet file: path={}, size={}, rows={}",
+ part_file, file_size, total_rows
+ );
+
// Return empty stream to indicate completion
Ok::<_, DataFusionError>(futures::stream::empty())
};
- // Execute the write task and convert to a stream
- use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+ // Execute the write task and create a stream that does not return any batches
Ok(Box::pin(RecordBatchStreamAdapter::new(
- output_schema,
+ self.schema(),
futures::stream::once(write_task).try_flatten(),
)))
}
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index b0746a6f8..ccbd0b250 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1466,6 +1466,13 @@ impl PhysicalPlanner {
let parquet_writer = Arc::new(ParquetWriterExec::try_new(
Arc::clone(&child.native_plan),
writer.output_path.clone(),
+ writer
+ .work_dir
+ .as_ref()
+ .expect("work_dir is provided")
+ .clone(),
+ writer.job_id.clone(),
+ writer.task_attempt_id,
codec,
self.partition,
writer.column_names.clone(),
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index a95832709..f09695b7c 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -241,6 +241,13 @@ message ParquetWriter {
string output_path = 1;
CompressionCodec compression = 2;
repeated string column_names = 4;
+ // Working directory for temporary files (used by FileCommitProtocol)
+ // If not set, files are written directly to output_path
+ optional string work_dir = 5;
+ // Job ID for tracking this write operation
+ optional string job_id = 6;
+ // Task attempt ID for this specific task
+ optional int32 task_attempt_id = 7;
}
enum AggregateMode {
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
index 46d01c887..7fdf05521 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
@@ -23,6 +23,7 @@ import java.util.Locale
import scala.jdk.CollectionConverters._
+import org.apache.spark.SparkException
import org.apache.spark.sql.comet.{CometNativeExec, CometNativeWriteExec}
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFilesExec}
@@ -129,6 +130,8 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
.setOutputPath(outputPath)
.setCompression(codec)
.addAllColumnNames(cmd.query.output.map(_.name).asJava)
+ // Note: work_dir, job_id, and task_attempt_id will be set at execution time
+ // in CometNativeWriteExec, as they depend on the Spark task context
.build()
val writerOperator = Operator
@@ -163,7 +166,29 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
other
}
- CometNativeWriteExec(nativeOp, childPlan, outputPath)
+ // Create FileCommitProtocol for atomic writes
+ val jobId = java.util.UUID.randomUUID().toString
+ val committer =
+ try {
+ // Use Spark's SQLHadoopMapReduceCommitProtocol
+ val committerClass =
+ classOf[org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol]
+ val constructor =
+ committerClass.getConstructor(classOf[String], classOf[String], classOf[Boolean])
+ Some(
+ constructor
+ .newInstance(
+ jobId,
+ outputPath,
+ java.lang.Boolean.FALSE // dynamicPartitionOverwrite = false for now
+ )
+ .asInstanceOf[org.apache.spark.internal.io.FileCommitProtocol])
+ } catch {
+ case e: Exception =>
+ throw new SparkException(s"Could not instantiate FileCommitProtocol: ${e.getMessage}")
+ }
+
+ CometNativeWriteExec(nativeOp, childPlan, outputPath, committer, jobId)
}
private def parseCompressionCodec(cmd: InsertIntoHadoopFsRelationCommand) = {
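A side note on the reflective construction above: Spark also ships a small factory for exactly this purpose, FileCommitProtocol.instantiate, which tries the three-argument (jobId, outputPath, dynamicPartitionOverwrite) constructor first and uses whatever committer class is configured via spark.sql.sources.commitProtocolClass. The sketch below is not what this patch does; it shows the equivalent call under the assumption that the SQL conf lookup is acceptable at this point in planning, with the object and method names chosen only for illustration.

import java.util.UUID

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.internal.SQLConf

object CommitterFactorySketch {
  // Build the committer through Spark's factory instead of raw reflection.
  // SQLConf.get.fileCommitProtocolClass defaults to SQLHadoopMapReduceCommitProtocol.
  def createCommitter(outputPath: String): (String, FileCommitProtocol) = {
    val jobId = UUID.randomUUID().toString
    val committer = FileCommitProtocol.instantiate(
      SQLConf.get.fileCommitProtocolClass,
      jobId = jobId,
      outputPath = outputPath,
      dynamicPartitionOverwrite = false)
    (jobId, committer)
  }
}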
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
index 2617e8c60..f153a691e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
@@ -19,54 +19,122 @@
package org.apache.spark.sql.comet
+import java.io.ByteArrayOutputStream
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.Utils
import org.apache.comet.CometExecIterator
import org.apache.comet.serde.OperatorOuterClass.Operator
/**
- * Comet physical operator for native Parquet write operations.
+ * Comet physical operator for native Parquet write operations with FileCommitProtocol support.
+ *
+ * This operator writes data to Parquet files using the native Comet engine. It integrates with
+ * Spark's FileCommitProtocol to provide atomic writes with proper staging and commit semantics.
*
- * This operator writes data to Parquet files using the native Comet engine. It wraps the child
- * operator and adds a ParquetWriter operator on top.
+ * The implementation includes support for Spark's file commit protocol through work_dir, job_id,
+ * and task_attempt_id parameters that can be set in the operator. When work_dir is set, files are
+ * written to a temporary location that can be atomically committed later.
*
* @param nativeOp
- * The native operator representing the write operation
+ * The native operator representing the write operation (template, will be modified per task)
* @param child
* The child operator providing the data to write
* @param outputPath
* The path where the Parquet file will be written
+ * @param committer
+ * FileCommitProtocol for atomic writes. If None, files are written directly.
+ * @param jobTrackerID
+ * Unique identifier for this write job
*/
-case class CometNativeWriteExec(nativeOp: Operator, child: SparkPlan, outputPath: String)
+case class CometNativeWriteExec(
+ nativeOp: Operator,
+ child: SparkPlan,
+ outputPath: String,
+ committer: Option[FileCommitProtocol] = None,
+ jobTrackerID: String = Utils.createTempDir().getName)
extends CometNativeExec
with UnaryExecNode {
override def originalPlan: SparkPlan = child
+ // Accumulator to collect TaskCommitMessages from all tasks
+ // Must be eagerly initialized on driver, not lazy
+ @transient private val taskCommitMessagesAccum =
+ sparkContext.collectionAccumulator[FileCommitProtocol.TaskCommitMessage]("taskCommitMessages")
+
override def serializedPlanOpt: SerializedPlan = {
- val outputStream = new java.io.ByteArrayOutputStream()
+ val outputStream = new ByteArrayOutputStream()
nativeOp.writeTo(outputStream)
outputStream.close()
SerializedPlan(Some(outputStream.toByteArray))
}
- override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ override def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
override def nodeName: String = "CometNativeWrite"
+ override lazy val metrics: Map[String, SQLMetric] = Map(
+ "files_written" -> SQLMetrics.createMetric(sparkContext, "number of
written data files"),
+ "bytes_written" -> SQLMetrics.createSizeMetric(sparkContext, "written
data"),
+ "rows_written" -> SQLMetrics.createMetric(sparkContext, "number of written
rows"))
+
override def doExecute(): RDD[InternalRow] = {
- // Execute the native write
+ // Setup job if committer is present
+ committer.foreach { c =>
+ val jobContext = createJobContext()
+ c.setupJob(jobContext)
+ }
+
+ // Execute the native write with commit protocol
val resultRDD = doExecuteColumnar()
- // Convert to empty InternalRow RDD (write operations typically return empty results)
- resultRDD.mapPartitions { iter =>
- // Consume all batches (they should be empty)
- iter.foreach(_.close())
- Iterator.empty
+
+ // Force execution by consuming all batches
+ resultRDD
+ .mapPartitions { iter =>
+ iter.foreach(_.close())
+ Iterator.empty
+ }
+ .count()
+
+ // Extract write statistics from metrics
+ val filesWritten = metrics("files_written").value
+ val bytesWritten = metrics("bytes_written").value
+ val rowsWritten = metrics("rows_written").value
+
+ // Collect TaskCommitMessages from accumulator
+ val commitMessages = taskCommitMessagesAccum.value.asScala.toSeq
+
+ // Commit job with collected TaskCommitMessages
+ committer.foreach { c =>
+ val jobContext = createJobContext()
+ try {
+ c.commitJob(jobContext, commitMessages)
+ logInfo(
+ s"Successfully committed write job to $outputPath: " +
+ s"$filesWritten files, $bytesWritten bytes, $rowsWritten rows")
+ } catch {
+ case e: Exception =>
+ logError("Failed to commit job, aborting", e)
+ c.abortJob(jobContext)
+ throw e
+ }
}
+
+ // Return empty RDD as write operations don't return data
+ sparkContext.emptyRDD[InternalRow]
}
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
@@ -86,27 +154,148 @@ case class CometNativeWriteExec(nativeOp: Operator, child: SparkPlan, outputPath
// Capture metadata before the transformation
val numPartitions = childRDD.getNumPartitions
val numOutputCols = child.output.length
+ val capturedCommitter = committer
+ val capturedJobTrackerID = jobTrackerID
+ val capturedNativeOp = nativeOp
+ val capturedAccumulator = taskCommitMessagesAccum // Capture accumulator for use in tasks
- // Execute native write operation
+ // Execute native write operation with task-level commit protocol
childRDD.mapPartitionsInternal { iter =>
+ val partitionId = org.apache.spark.TaskContext.getPartitionId()
+ val taskAttemptId = org.apache.spark.TaskContext.get().taskAttemptId()
+
+ // Setup task-level commit protocol if provided
+ val (workDir, taskContext, commitMsg) = capturedCommitter
+ .map { committer =>
+ val taskContext =
createTaskContext(capturedJobTrackerID, partitionId, taskAttemptId.toInt)
+
+ // Setup task - this creates the temporary working directory
+ committer.setupTask(taskContext)
+
+ // Get the work directory for temp files
+ val workPath = committer.newTaskTempFile(taskContext, None, "")
+ val workDir = new Path(workPath).getParent.toString
+
+ (Some(workDir), Some((committer, taskContext)), null)
+ }
+ .getOrElse((None, None, null))
+
+ // Modify the native operator to include task-specific parameters
+ val modifiedNativeOp = if (workDir.isDefined) {
+ val parquetWriter = capturedNativeOp.getParquetWriter.toBuilder
+ .setWorkDir(workDir.get)
+ .setJobId(capturedJobTrackerID)
+ .setTaskAttemptId(taskAttemptId.toInt)
+ .build()
+
+ capturedNativeOp.toBuilder.setParquetWriter(parquetWriter).build()
+ } else {
+ capturedNativeOp
+ }
+
val nativeMetrics = CometMetricNode.fromCometPlan(this)
- val outputStream = new java.io.ByteArrayOutputStream()
- nativeOp.writeTo(outputStream)
+ val outputStream = new ByteArrayOutputStream()
+ modifiedNativeOp.writeTo(outputStream)
outputStream.close()
val planBytes = outputStream.toByteArray
- new CometExecIterator(
+ val execIterator = new CometExecIterator(
CometExec.newIterId,
Seq(iter),
numOutputCols,
planBytes,
nativeMetrics,
numPartitions,
- org.apache.spark.TaskContext.getPartitionId(),
+ partitionId,
None,
Seq.empty)
+ // Wrap the iterator to handle task commit/abort and capture TaskCommitMessage
+ new Iterator[ColumnarBatch] {
+ private var completed = false
+ private var thrownException: Option[Throwable] = None
+
+ override def hasNext: Boolean = {
+ val result =
+ try {
+ execIterator.hasNext
+ } catch {
+ case e: Throwable =>
+ thrownException = Some(e)
+ handleTaskEnd()
+ throw e
+ }
+
+ if (!result && !completed) {
+ handleTaskEnd()
+ }
+
+ result
+ }
+
+ override def next(): ColumnarBatch = {
+ try {
+ execIterator.next()
+ } catch {
+ case e: Throwable =>
+ thrownException = Some(e)
+ handleTaskEnd()
+ throw e
+ }
+ }
+
+ private def handleTaskEnd(): Unit = {
+ if (!completed) {
+ completed = true
+
+ // Handle commit or abort based on whether an exception was thrown
+ taskContext.foreach { case (committer, ctx) =>
+ try {
+ if (thrownException.isEmpty) {
+ // Commit the task and add message to accumulator
+ val message = committer.commitTask(ctx)
+ capturedAccumulator.add(message)
+ logInfo(s"Task ${ctx.getTaskAttemptID} committed
successfully")
+ } else {
+ // Abort the task
+ committer.abortTask(ctx)
+ val exMsg = thrownException.get.getMessage
+ logWarning(s"Task ${ctx.getTaskAttemptID} aborted due to
exception: $exMsg")
+ }
+ } catch {
+ case e: Exception =>
+ // Log the commit/abort exception but don't mask the original exception
+ logError(s"Error during task commit/abort: ${e.getMessage}", e)
+ if (thrownException.isEmpty) {
+ // If no original exception, propagate the commit/abort exception
+ throw e
+ }
+ }
+ }
+ }
+ }
+ }
}
}
+
+ /** Create a JobContext for the write job */
+ private def createJobContext(): Job = {
+ val job = Job.getInstance()
+ job.setJobID(new org.apache.hadoop.mapreduce.JobID(jobTrackerID, 0))
+ job
+ }
+
+ /** Create a TaskAttemptContext for a specific task */
+ private def createTaskContext(
+ jobId: String,
+ partitionId: Int,
+ attemptNumber: Int): TaskAttemptContext = {
+ val job = Job.getInstance()
+ val taskAttemptID = new TaskAttemptID(
+ new TaskID(new org.apache.hadoop.mapreduce.JobID(jobId, 0), TaskType.REDUCE, partitionId),
+ attemptNumber)
+ new TaskAttemptContextImpl(job.getConfiguration, taskAttemptID)
+ }
}
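For completeness, the new path is driven from an ordinary DataFrame write once Comet is loaded. The usage sketch below assumes a local build with native execution enabled; spark.comet.enabled and spark.comet.exec.enabled are Comet's standard switches, and any additional flag specifically gating native Parquet writes is not shown in this commit. The object name and output path are hypothetical.

import org.apache.spark.sql.SparkSession

object NativeWriteUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("comet-native-write-sketch")
      .master("local[2]")
      .config("spark.plugins", "org.apache.spark.CometPlugin")
      .config("spark.comet.enabled", "true")
      .config("spark.comet.exec.enabled", "true")
      .getOrCreate()

    // When the plan is offloaded, CometNativeWriteExec stages part files in the
    // committer's work directory on each task, and the driver commits them to the
    // final path after collecting the TaskCommitMessages.
    spark
      .range(0, 1000)
      .withColumnRenamed("id", "value")
      .write
      .mode("overwrite")
      .parquet("/tmp/comet-native-write-sketch") // hypothetical output path

    spark.stop()
  }
}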