This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 638078f4ea44 [SPARK-46520][PYTHON] Support overwrite mode for Python data source write
638078f4ea44 is described below

commit 638078f4ea445c67a4ca78c45ad8fc4a6c9b399b
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Wed Dec 27 20:44:42 2023 +0900

    [SPARK-46520][PYTHON] Support overwrite mode for Python data source write

    ### What changes were proposed in this pull request?

    This PR supports the `overwrite` mode for Python data source writes, in addition to the append mode.

    ### Why are the changes needed?

    To support Python data sources.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    New unit tests

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #44505 from allisonwang-db/spark-46520-overwrite.

    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../pyspark/sql/worker/write_into_data_source.py  |  7 +--
 .../python/UserDefinedPythonDataSource.scala      | 30 +++++++-----
 .../execution/python/PythonDataSourceSuite.scala  | 56 +++++++++++++++++++++-
 3 files changed, 77 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 9c311dad0334..eea4a75e3be4 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -24,6 +24,7 @@ from pyspark.sql.connect.conversion import ArrowTableToRowsConversion
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import (
+    read_bool,
     read_int,
     write_int,
     SpecialLengths,
@@ -148,8 +149,8 @@ def main(infile: IO, outfile: IO) -> None:
             value = utf8_deserializer.loads(infile)
             options[key] = value
 
-        # Receive the save mode.
-        save_mode = utf8_deserializer.loads(infile)
+        # Receive the `overwrite` flag.
+        overwrite = read_bool(infile)
 
         # Instantiate a data source.
         try:
@@ -162,7 +163,7 @@ def main(infile: IO, outfile: IO) -> None:
 
         # Instantiate the data source writer.
         try:
-            writer = data_source.writer(schema, save_mode)
+            writer = data_source.writer(schema, overwrite)
         except Exception as e:
             raise PySparkRuntimeError(
                 error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
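
With this change, the user-facing contract of the Python data source API is that
`DataSource.writer(schema, overwrite)` receives a plain boolean instead of a save-mode
string. A minimal sketch of a source honoring that flag (the `MySink`/`MySinkWriter`
names and the no-op write are illustrative, not part of this patch):

    from typing import Iterator

    from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage


    class MySinkWriter(DataSourceWriter):
        def __init__(self, options, overwrite: bool):
            self.options = options
            self.overwrite = overwrite

        def write(self, iterator: Iterator) -> WriterCommitMessage:
            # A real writer would persist the rows here, truncating existing
            # output first when self.overwrite is True and appending otherwise.
            for _ in iterator:
                pass
            return WriterCommitMessage()


    class MySink(DataSource):
        def writer(self, schema, overwrite: bool) -> DataSourceWriter:
            # overwrite is True for mode("overwrite") and False for mode("append").
            return MySinkWriter(self.options, overwrite)
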
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 00974a7e2974..778f55595aee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -26,17 +26,17 @@ import net.razorvine.pickle.Pickler
 
 import org.apache.spark.{JobArtifactSet, SparkException}
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
-import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.PythonUDF
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider}
-import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE}
+import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, TRUNCATE}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, SupportsTruncate, Write, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
@@ -77,7 +77,7 @@ class PythonTableProvider extends TableProvider {
   override def name(): String = shortName
 
   override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(
-    BATCH_READ, BATCH_WRITE)
+    BATCH_READ, BATCH_WRITE, TRUNCATE)
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
     new ScanBuilder with Batch with Scan {
@@ -115,7 +115,15 @@ class PythonTableProvider extends TableProvider {
       override def schema(): StructType = outputSchema
 
       override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
-        new WriteBuilder {
+        new WriteBuilder with SupportsTruncate {
+
+          private var isTruncate = false
+
+          override def truncate(): WriteBuilder = {
+            isTruncate = true
+            this
+          }
+
           override def build(): Write = new Write {
 
             override def toBatch: BatchWrite = new BatchWrite {
@@ -127,7 +135,7 @@ class PythonTableProvider extends TableProvider {
                   shortName,
                   info.schema(),
                   info.options(),
-                  SaveMode.Append)
+                  isTruncate)
                 PythonBatchWriterFactory(source, writeInfo.func, info.schema(), jobArtifactUUID)
               }
 
@@ -316,13 +324,13 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       provider: String,
       inputSchema: StructType,
       options: CaseInsensitiveStringMap,
-      mode: SaveMode): PythonDataSourceWriteInfo = {
+      overwrite: Boolean): PythonDataSourceWriteInfo = {
     new UserDefinedPythonDataSourceWriteRunner(
       dataSourceCls,
       provider,
       inputSchema,
       options.asCaseSensitiveMap().asScala.toMap,
-      mode).runInPython()
+      overwrite).runInPython()
   }
 
   /**
@@ -540,7 +548,7 @@ class UserDefinedPythonDataSourceWriteRunner(
     provider: String,
     inputSchema: StructType,
     options: Map[String, String],
-    mode: SaveMode) extends PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) {
+    overwrite: Boolean) extends PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) {
 
   override val workerModule: String = "pyspark.sql.worker.write_into_data_source"
 
@@ -564,8 +572,8 @@ class UserDefinedPythonDataSourceWriteRunner(
       PythonWorkerUtils.writeUTF(value, dataOut)
     }
 
-    // Send the mode
-    PythonWorkerUtils.writeUTF(mode.toString, dataOut)
+    // Send the `overwrite` flag
+    dataOut.writeBoolean(overwrite)
   }
 
   override protected def receiveFromPython(
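
On the JVM side, advertising the TRUNCATE capability lets the DSv2 write path call
`truncate()` on the `WriteBuilder` when an overwrite is requested; the builder records
that as a boolean and passes it to the Python runner in place of the old `SaveMode`.
Roughly, assuming an active SparkSession `spark` and a Python data source registered
under the hypothetical short name "my_sink" (the test suite below also checks that
ErrorIfExists is rejected with a message pointing users to Append or Overwrite):

    df = spark.range(3)

    # mode("append"): truncate() is never called on the WriteBuilder, so the
    # Python writer is created with overwrite=False.
    df.write.format("my_sink").mode("append").save("/tmp/my_sink_output")

    # mode("overwrite"): Spark calls truncate() (TRUNCATE capability), so the
    # Python writer is created with overwrite=True.
    df.write.format("my_sink").mode("overwrite").save("/tmp/my_sink_output")
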
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index b04569ae5541..c261f1d529fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -462,6 +462,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val dataSourceScript =
       s"""
          |import json
+         |import os
          |from pyspark import TaskContext
          |from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage
          |
@@ -474,7 +475,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
          |        partition_id = context.partitionId()
          |        path = self.options.get("path")
          |        assert path is not None
-         |        output_path = f"{path}/{partition_id}.json"
+         |        output_path = os.path.join(path, f"{partition_id}.json")
          |        cnt = 0
          |        with open(output_path, "w") as file:
          |            for row in iterator:
@@ -483,7 +484,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
          |        return WriterCommitMessage()
          |
          |class SimpleDataSource(DataSource):
-         |    def writer(self, schema, saveMode):
+         |    def writer(self, schema, overwrite):
          |        return SimpleDataSourceWriter(self.options)
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
@@ -548,4 +549,55 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
         "cannot be written with ErrorIfExists mode, please use Append or Overwrite modes instead."))
     }
   }
+
+  test("data source write - overwrite mode") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |import json
+         |import os
+         |from pyspark import TaskContext
+         |from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage
+         |
+         |class SimpleDataSourceWriter(DataSourceWriter):
+         |    def __init__(self, options, overwrite):
+         |        self.options = options
+         |        self.overwrite = overwrite
+         |
+         |    def write(self, iterator):
+         |        context = TaskContext.get()
+         |        partition_id = context.partitionId()
+         |        path = self.options.get("path")
+         |        assert path is not None
+         |        output_path = os.path.join(path, f"{partition_id}.json")
+         |        cnt = 0
+         |        mode = "w" if self.overwrite else "a"
+         |        with open(output_path, mode) as file:
+         |            for row in iterator:
+         |                file.write(json.dumps(row.asDict()) + "\\n")
+         |                cnt += 1
+         |        return WriterCommitMessage()
+         |
+         |class SimpleDataSource(DataSource):
+         |    def writer(self, schema, overwrite):
+         |        return SimpleDataSourceWriter(self.options, overwrite)
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      spark.range(1).write.format(dataSourceName).mode("append").save(path)
+      checkAnswer(
+        spark.read.json(path),
+        Seq(Row(0)))
+      spark.range(1).write.format(dataSourceName).mode("append").save(path)
+      checkAnswer(
+        spark.read.json(path),
+        Seq(Row(0), Row(0)))
+      spark.range(2, 3).write.format(dataSourceName).mode("overwrite").save(path)
+      checkAnswer(
+        spark.read.json(path),
+        Seq(Row(2)))
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org