This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 638078f4ea44 [SPARK-46520][PYTHON] Support overwrite mode for Python data source write
638078f4ea44 is described below

commit 638078f4ea445c67a4ca78c45ad8fc4a6c9b399b
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Wed Dec 27 20:44:42 2023 +0900

    [SPARK-46520][PYTHON] Support overwrite mode for Python data source write
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for the `overwrite` mode for Python data source writes, in addition to the existing `append` mode.
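
    With this change, `DataSource.writer(schema, overwrite)` receives a boolean `overwrite` flag instead of a save-mode string, and `mode("overwrite")` is translated into a truncating write on the Scala side via `SupportsTruncate`. A minimal sketch of a user-defined writer under the new signature (class names here are illustrative, not part of this patch):

    ```python
    from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage

    class MyWriter(DataSourceWriter):
        def __init__(self, options, overwrite):
            self.options = options
            # True for mode("overwrite"), False for mode("append").
            self.overwrite = overwrite

        def write(self, iterator):
            # A real writer would clear any existing output first when
            # self.overwrite is True, then persist the rows from `iterator`.
            for row in iterator:
                pass
            return WriterCommitMessage()

    class MyDataSource(DataSource):
        def writer(self, schema, overwrite):
            # `overwrite` is now a bool rather than a save-mode string.
            return MyWriter(self.options, overwrite)
    ```

    A `df.write.format(...).mode("overwrite").save(path)` call then reaches the Python writer with `overwrite=True`.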
    
    ### Why are the changes needed?
    
    To support Python data sources.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44505 from allisonwang-db/spark-46520-overwrite.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../pyspark/sql/worker/write_into_data_source.py   |  7 +--
 .../python/UserDefinedPythonDataSource.scala       | 30 +++++++-----
 .../execution/python/PythonDataSourceSuite.scala   | 56 +++++++++++++++++++++-
 3 files changed, 77 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 9c311dad0334..eea4a75e3be4 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -24,6 +24,7 @@ from pyspark.sql.connect.conversion import ArrowTableToRowsConversion
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import (
+    read_bool,
     read_int,
     write_int,
     SpecialLengths,
@@ -148,8 +149,8 @@ def main(infile: IO, outfile: IO) -> None:
             value = utf8_deserializer.loads(infile)
             options[key] = value
 
-        # Receive the save mode.
-        save_mode = utf8_deserializer.loads(infile)
+        # Receive the `overwrite` flag.
+        overwrite = read_bool(infile)
 
         # Instantiate a data source.
         try:
@@ -162,7 +163,7 @@ def main(infile: IO, outfile: IO) -> None:
 
         # Instantiate the data source writer.
         try:
-            writer = data_source.writer(schema, save_mode)
+            writer = data_source.writer(schema, overwrite)
         except Exception as e:
             raise PySparkRuntimeError(
                 error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 00974a7e2974..778f55595aee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -26,17 +26,17 @@ import net.razorvine.pickle.Pickler
 
 import org.apache.spark.{JobArtifactSet, SparkException}
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
-import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.PythonUDF
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider}
-import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE}
+import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, TRUNCATE}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, SupportsTruncate, Write, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
@@ -77,7 +77,7 @@ class PythonTableProvider extends TableProvider {
       override def name(): String = shortName
 
       override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(
-        BATCH_READ, BATCH_WRITE)
+        BATCH_READ, BATCH_WRITE, TRUNCATE)
 
       override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
         new ScanBuilder with Batch with Scan {
@@ -115,7 +115,15 @@ class PythonTableProvider extends TableProvider {
       override def schema(): StructType = outputSchema
 
       override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
-        new WriteBuilder {
+        new WriteBuilder with SupportsTruncate {
+
+          private var isTruncate = false
+
+          override def truncate(): WriteBuilder = {
+            isTruncate = true
+            this
+          }
+
           override def build(): Write = new Write {
 
             override def toBatch: BatchWrite = new BatchWrite {
@@ -127,7 +135,7 @@ class PythonTableProvider extends TableProvider {
                   shortName,
                   info.schema(),
                   info.options(),
-                  SaveMode.Append)
+                  isTruncate)
                PythonBatchWriterFactory(source, writeInfo.func, info.schema(), jobArtifactUUID)
               }
 
@@ -316,13 +324,13 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       provider: String,
       inputSchema: StructType,
       options: CaseInsensitiveStringMap,
-      mode: SaveMode): PythonDataSourceWriteInfo = {
+      overwrite: Boolean): PythonDataSourceWriteInfo = {
     new UserDefinedPythonDataSourceWriteRunner(
       dataSourceCls,
       provider,
       inputSchema,
       options.asCaseSensitiveMap().asScala.toMap,
-      mode).runInPython()
+      overwrite).runInPython()
   }
 
   /**
@@ -540,7 +548,7 @@ class UserDefinedPythonDataSourceWriteRunner(
     provider: String,
     inputSchema: StructType,
     options: Map[String, String],
-    mode: SaveMode) extends PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) {
+    overwrite: Boolean) extends PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) {
 
   override val workerModule: String = "pyspark.sql.worker.write_into_data_source"
 
@@ -564,8 +572,8 @@ class UserDefinedPythonDataSourceWriteRunner(
       PythonWorkerUtils.writeUTF(value, dataOut)
     }
 
-    // Send the mode
-    PythonWorkerUtils.writeUTF(mode.toString, dataOut)
+    // Send the `overwrite` flag
+    dataOut.writeBoolean(overwrite)
   }
 
   override protected def receiveFromPython(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index b04569ae5541..c261f1d529fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -462,6 +462,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val dataSourceScript =
       s"""
          |import json
+         |import os
          |from pyspark import TaskContext
         |from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage
          |
@@ -474,7 +475,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
          |        partition_id = context.partitionId()
          |        path = self.options.get("path")
          |        assert path is not None
-         |        output_path = f"{path}/{partition_id}.json"
+         |        output_path = os.path.join(path, f"{partition_id}.json")
          |        cnt = 0
          |        with open(output_path, "w") as file:
          |            for row in iterator:
@@ -483,7 +484,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
          |        return WriterCommitMessage()
          |
          |class SimpleDataSource(DataSource):
-         |    def writer(self, schema, saveMode):
+         |    def writer(self, schema, overwrite):
          |        return SimpleDataSourceWriter(self.options)
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
@@ -548,4 +549,55 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
         "cannot be written with ErrorIfExists mode, please use Append or Overwrite modes instead."))
     }
   }
+
+  test("data source write - overwrite mode") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |import json
+         |import os
+         |from pyspark import TaskContext
+         |from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage
+         |
+         |class SimpleDataSourceWriter(DataSourceWriter):
+         |    def __init__(self, options, overwrite):
+         |        self.options = options
+         |        self.overwrite = overwrite
+         |
+         |    def write(self, iterator):
+         |        context = TaskContext.get()
+         |        partition_id = context.partitionId()
+         |        path = self.options.get("path")
+         |        assert path is not None
+         |        output_path = os.path.join(path, f"{partition_id}.json")
+         |        cnt = 0
+         |        mode = "w" if self.overwrite else "a"
+         |        with open(output_path, mode) as file:
+         |            for row in iterator:
+         |                file.write(json.dumps(row.asDict()) + "\\n")
+         |                cnt += 1
+         |        return WriterCommitMessage()
+         |
+         |class SimpleDataSource(DataSource):
+         |    def writer(self, schema, overwrite):
+         |        return SimpleDataSourceWriter(self.options, overwrite)
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      spark.range(1).write.format(dataSourceName).mode("append").save(path)
+      checkAnswer(
+        spark.read.json(path),
+        Seq(Row(0)))
+      spark.range(1).write.format(dataSourceName).mode("append").save(path)
+      checkAnswer(
+        spark.read.json(path),
+        Seq(Row(0), Row(0)))
+      spark.range(2, 3).write.format(dataSourceName).mode("overwrite").save(path)
+      checkAnswer(
+        spark.read.json(path),
+        Seq(Row(2)))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
