This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new c9cfaac90fd4 [SPARK-46452][SQL] Add a new API in DataWriter to write an iterator of records
c9cfaac90fd4 is described below

commit c9cfaac90fd423c3a38e295234e24744b946cb02
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Wed Dec 20 19:17:21 2023 +0800

    [SPARK-46452][SQL] Add a new API in DataWriter to write an iterator of records

    ### What changes were proposed in this pull request?

    This PR proposes to add a new method in `DataWriter` that supports writing an iterator of records:

    ```java
    void writeAll(Iterator<T> records) throws IOException
    ```

    ### Why are the changes needed?

    To make the API more flexible and support more use cases (e.g. Python data sources). See https://github.com/apache/spark/pull/43791

    ### Does this PR introduce _any_ user-facing change?

    Yes. This PR introduces a new method in `DataWriter`.

    ### How was this patch tested?

    Existing unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #44410 from allisonwang-db/spark-46452-dsv2-write-all.

    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
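The point of handing the writer the whole iterator is that an implementation can own the consumption loop, for example to buffer rows and flush them to the sink in batches instead of reacting to one write() call at a time. A minimal sketch of such an override follows; `BatchingRowWriter`, `flushBatch`, and `BATCH_SIZE` are illustrative names and not part of the Spark API, and a real connector would push each batch to its sink rather than discard it:

```java
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;

// Hypothetical connector writer that overrides the new writeAll to flush
// records in fixed-size batches rather than relying on per-record write().
public class BatchingRowWriter implements DataWriter<InternalRow> {
  private static final int BATCH_SIZE = 1000;
  private final List<InternalRow> buffer = new ArrayList<>();

  @Override
  public void write(InternalRow row) throws IOException {
    buffer.add(row.copy());  // copy: Spark may reuse the row instance
    if (buffer.size() >= BATCH_SIZE) {
      flushBatch();
    }
  }

  @Override
  public void writeAll(Iterator<InternalRow> records) throws IOException {
    // Owning the loop lets the writer decide when to flush, instead of
    // being driven one record at a time by the caller.
    while (records.hasNext()) {
      write(records.next());
    }
    flushBatch();  // drain the tail batch before commit
  }

  @Override
  public WriterCommitMessage commit() throws IOException {
    flushBatch();
    return new WriterCommitMessage() {};  // marker interface; real writers return task state
  }

  @Override
  public void abort() throws IOException {
    buffer.clear();  // a real writer would also clean up partially written data
  }

  @Override
  public void close() throws IOException {}

  private void flushBatch() throws IOException {
    // Stub: a real implementation would push `buffer` to the external sink here.
    buffer.clear();
  }
}
```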
---
 .../spark/sql/connector/write/DataWriter.java      |  18 +++
 .../datasources/v2/WriteToDataSourceV2Exec.scala   | 121 ++++++++++++---
 2 files changed, 88 insertions(+), 51 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
index 6a1cee181bc2..d6e94fe2ca8b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.write;
 
 import java.io.Closeable;
 import java.io.IOException;
+import java.util.Iterator;
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.metric.CustomTaskMetric;
@@ -74,6 +75,23 @@ public interface DataWriter<T> extends Closeable {
    */
   void write(T record) throws IOException;
 
+  /**
+   * Writes all records provided by the given iterator. By default, it calls the {@link #write}
+   * method for each record in the iterator.
+   * <p>
+   * If this method fails (by throwing an exception), {@link #abort()} will be called and this
+   * data writer is considered to have failed.
+   *
+   * @throws IOException if failure happens during disk/network IO like writing files.
+   *
+   * @since 4.0.0
+   */
+  default void writeAll(Iterator<T> records) throws IOException {
+    while (records.hasNext()) {
+      write(records.next());
+    }
+  }
+
   /**
    * Commits this writer after all records are written successfully, returns a commit message which
    * will be sent back to driver side and passed to
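To make the javadoc contract above concrete: a caller invokes writeAll once with the partition's records, commits only if it returns normally, aborts otherwise, and always closes the writer (the Scala changes below wire exactly this into Spark's write task). A caller-side sketch under those assumptions; `WriteTaskSketch` and `writeAndCommit` are illustrative names, not Spark code:

```java
import java.io.IOException;
import java.util.Iterator;

import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;

// Illustrative helper mirroring the task-side contract around writeAll.
final class WriteTaskSketch {
  static <T> WriterCommitMessage writeAndCommit(DataWriter<T> writer, Iterator<T> records)
      throws IOException {
    try {
      writer.writeAll(records);  // default implementation drains the iterator via write()
      return writer.commit();    // commit only after all records were written
    } catch (IOException | RuntimeException e) {
      writer.abort();            // per the javadoc: abort when writeAll fails
      throw e;
    } finally {
      writer.close();
    }
  }
}
```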
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 2527f201f3a8..97c1f7ced508 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -421,7 +421,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
 
 trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serializable {
 
-  protected def write(writer: W, row: InternalRow): Unit
+  protected def write(writer: W, iter: java.util.Iterator[InternalRow]): Unit
 
   def run(
       writerFactory: DataWriterFactory,
@@ -436,19 +436,11 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
     val attemptId = context.attemptNumber()
     val dataWriter = writerFactory.createWriter(partId, taskId).asInstanceOf[W]
 
-    var count = 0L
+    val iterWithMetrics = IteratorWithMetrics(iter, dataWriter, customMetrics)
+
     // write the data and commit this writer.
     Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
-      while (iter.hasNext) {
-        if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
-          CustomMetrics.updateMetrics(
-            dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
-        }
-
-        // Count is here.
-        count += 1
-        write(dataWriter, iter.next())
-      }
+      write(dataWriter, iterWithMetrics)
 
       CustomMetrics.updateMetrics(
         dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
@@ -476,7 +468,7 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
       logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId, " +
         s"stage $stageId.$stageAttempt)")
 
-      DataWritingSparkTaskResult(count, msg)
+      DataWritingSparkTaskResult(iterWithMetrics.count, msg)
 
     })(catchBlock = {
       // If there is an error, abort this writer
@@ -489,11 +481,30 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
       dataWriter.close()
     })
   }
+
+  private case class IteratorWithMetrics(
+      iter: Iterator[InternalRow],
+      dataWriter: W,
+      customMetrics: Map[String, SQLMetric]) extends java.util.Iterator[InternalRow] {
+    var count = 0L
+
+    override def hasNext: Boolean = iter.hasNext
+
+    override def next(): InternalRow = {
+      if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
+        CustomMetrics.updateMetrics(
+          dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
+      }
+      count += 1
+      iter.next()
+    }
+  }
 }
 
 object DataWritingSparkTask extends WritingSparkTask[DataWriter[InternalRow]] {
-  override protected def write(writer: DataWriter[InternalRow], row: InternalRow): Unit = {
-    writer.write(row)
+  override protected def write(
+      writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
+    writer.writeAll(iter)
   }
 }
 
@@ -503,25 +514,29 @@ case class DeltaWritingSparkTask(
   private lazy val rowProjection = projections.rowProjection.orNull
   private lazy val rowIdProjection = projections.rowIdProjection
 
-  override protected def write(writer: DeltaWriter[InternalRow], row: InternalRow): Unit = {
-    val operation = row.getInt(0)
+  override protected def write(
+      writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
+    while (iter.hasNext) {
+      val row = iter.next()
+      val operation = row.getInt(0)
 
-    operation match {
-      case DELETE_OPERATION =>
-        rowIdProjection.project(row)
-        writer.delete(null, rowIdProjection)
+      operation match {
+        case DELETE_OPERATION =>
+          rowIdProjection.project(row)
+          writer.delete(null, rowIdProjection)
 
-      case UPDATE_OPERATION =>
-        rowProjection.project(row)
-        rowIdProjection.project(row)
-        writer.update(null, rowIdProjection, rowProjection)
+        case UPDATE_OPERATION =>
+          rowProjection.project(row)
+          rowIdProjection.project(row)
+          writer.update(null, rowIdProjection, rowProjection)
 
-      case INSERT_OPERATION =>
-        rowProjection.project(row)
-        writer.insert(rowProjection)
+        case INSERT_OPERATION =>
+          rowProjection.project(row)
+          writer.insert(rowProjection)
 
-      case other =>
-        throw new SparkException(s"Unexpected operation ID: $other")
+        case other =>
+          throw new SparkException(s"Unexpected operation ID: $other")
+      }
     }
   }
 }
@@ -533,27 +548,31 @@ case class DeltaWithMetadataWritingSparkTask(
   private lazy val rowIdProjection = projections.rowIdProjection
   private lazy val metadataProjection = projections.metadataProjection.orNull
 
-  override protected def write(writer: DeltaWriter[InternalRow], row: InternalRow): Unit = {
-    val operation = row.getInt(0)
-
-    operation match {
-      case DELETE_OPERATION =>
-        rowIdProjection.project(row)
-        metadataProjection.project(row)
-        writer.delete(metadataProjection, rowIdProjection)
-
-      case UPDATE_OPERATION =>
-        rowProjection.project(row)
-        rowIdProjection.project(row)
-        metadataProjection.project(row)
-        writer.update(metadataProjection, rowIdProjection, rowProjection)
-
-      case INSERT_OPERATION =>
-        rowProjection.project(row)
-        writer.insert(rowProjection)
-
-      case other =>
-        throw new SparkException(s"Unexpected operation ID: $other")
+  override protected def write(
+      writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
+    while (iter.hasNext) {
+      val row = iter.next()
+      val operation = row.getInt(0)
+
+      operation match {
+        case DELETE_OPERATION =>
+          rowIdProjection.project(row)
+          metadataProjection.project(row)
+          writer.delete(metadataProjection, rowIdProjection)
+
+        case UPDATE_OPERATION =>
+          rowProjection.project(row)
+          rowIdProjection.project(row)
+          metadataProjection.project(row)
+          writer.update(metadataProjection, rowIdProjection, rowProjection)
+
+        case INSERT_OPERATION =>
+          rowProjection.project(row)
+          writer.insert(rowProjection)
+
+        case other =>
+          throw new SparkException(s"Unexpected operation ID: $other")
+      }
     }
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org