This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 14d347c6c85 [SPARK-45654][PYTHON] Add Python data source write API
14d347c6c85 is described below

commit 14d347c6c85172cee904bd3131392839594d5f2f
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Wed Nov 1 09:14:23 2023 +0900

    [SPARK-45654][PYTHON] Add Python data source write API
    
    ### What changes were proposed in this pull request?
    
    This PR adds the Python data source write API and a `DataSourceWriter` class to `datasource.py`.
    
    Here is an overview of the writer class:
    
    ```python
    class DataSourceWriter(ABC):
        @abstractmethod
        def write(self, iterator: Iterator[Row]) -> "WriterCommitMessage":
            ...

        def commit(self, messages: List["WriterCommitMessage"]) -> None:
            ...

        def abort(self, messages: List["WriterCommitMessage"]) -> None:
            ...
    ```
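
    For illustration only (nothing below is part of this PR), a minimal writer built
    on this API might look like the following sketch; `SimpleWriter` and
    `SimpleCommitMessage` are hypothetical names:

    ```python
    from dataclasses import dataclass
    from typing import Iterator, List

    from pyspark.sql import Row
    from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage

    @dataclass
    class SimpleCommitMessage(WriterCommitMessage):
        count: int  # per-task row count, sent back to the driver

    class SimpleWriter(DataSourceWriter):
        def write(self, iterator: Iterator[Row]) -> SimpleCommitMessage:
            # Runs once per task on an executor; a real writer would persist
            # each row to the underlying storage here.
            return SimpleCommitMessage(count=sum(1 for _ in iterator))

        def commit(self, messages: List[SimpleCommitMessage]) -> None:
            # Runs on the driver once all tasks have succeeded.
            print(f"committed {sum(m.count for m in messages)} rows")

        def abort(self, messages: List[SimpleCommitMessage]) -> None:
            # Runs on the driver if any task fails; clean up partial output here.
            print("write aborted")
    ```

    A corresponding `DataSource` implementation would return `SimpleWriter` from
    its `writer(schema, saveMode)` method.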
    
    ### Why are the changes needed?
    
    To support Python data source write.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. This PR alone does not introduce any user-facing change.
    
    ### How was this patch tested?
    
    Unit tests in `test_python_datasource.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43516 from allisonwang-db/spark-45654-write-api.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/datasource.py                   | 99 +++++++++++++++++++++-
 python/pyspark/sql/tests/test_python_datasource.py |  2 +
 2 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index 5cda6596b3f..c30a2c8689d 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -26,9 +26,10 @@ if TYPE_CHECKING:
     from pyspark.sql.session import SparkSession
 
 
-__all__ = ["DataSource", "DataSourceReader", "DataSourceRegistration"]
+__all__ = ["DataSource", "DataSourceReader", "DataSourceWriter", 
"DataSourceRegistration"]
 
 
+@since(4.0)
 class DataSource(ABC):
     """
     A base class for data sources.
@@ -133,7 +134,29 @@ class DataSource(ABC):
         """
         raise NotImplementedError
 
+    def writer(self, schema: StructType, saveMode: str) -> "DataSourceWriter":
+        """
+        Returns a ``DataSourceWriter`` instance for writing data.
+
+        The implementation is required for writable data sources.
+
+        Parameters
+        ----------
+        schema : StructType
+            The schema of the data to be written.
+        saveMode : str
+            A string that identifies the save mode. It can be one of the following:
+            `append`, `overwrite`, `error`, `ignore`.
+
+        Returns
+        -------
+        writer : DataSourceWriter
+            A writer instance for this data source.
+        """
+        raise NotImplementedError
+
 
+@since(4.0)
 class DataSourceReader(ABC):
     """
     A base class for data source readers. Data source readers are responsible for
@@ -229,6 +252,80 @@ class DataSourceReader(ABC):
         ...
 
 
+@since(4.0)
+class DataSourceWriter(ABC):
+    """
+    A base class for data source writers. Data source writers are responsible for saving
+    the data to the data source.
+    """
+
+    @abstractmethod
+    def write(self, iterator: Iterator[Row]) -> "WriterCommitMessage":
+        """
+        Writes data into the data source.
+
+        This method is called once on each executor to write data to the data source.
+        It accepts an iterator of input data and returns a single row representing a
+        commit message, or None if there is no commit message.
+
+        The driver collects commit messages, if any, from all executors and passes them
+        to the ``commit`` method if all tasks run successfully. If any task fails, the
+        ``abort`` method will be called with the collected commit messages.
+
+        Parameters
+        ----------
+        iterator : Iterator[Row]
+            An iterator of input data.
+
+        Returns
+        -------
+        WriterCommitMessage : a serializable commit message
+        """
+        ...
+
+    def commit(self, messages: List["WriterCommitMessage"]) -> None:
+        """
+        Commits this writing job with a list of commit messages.
+
+        This method is invoked on the driver when all tasks run successfully. The
+        commit messages are collected from the ``write`` method call from each task,
+        and are passed to this method. The implementation should use the commit
+        messages to commit the writing job to the data source.
+
+        Parameters
+        ----------
+        messages : List[WriterCommitMessage]
+            A list of commit messages.
+        """
+        ...
+
+    def abort(self, messages: List["WriterCommitMessage"]) -> None:
+        """
+        Aborts this writing job due to task failures.
+
+        This method is invoked on the driver when one or more tasks fail. The commit
+        messages are collected from the ``write`` method call from each task, and are
+        passed to this method. The implementation should use the commit messages to
+        abort the writing job to the data source.
+
+        Parameters
+        ----------
+        messages : List[WriterCommitMessage]
+            A list of commit messages.
+        """
+        ...
+
+
+@since(4.0)
+class WriterCommitMessage:
+    """
+    A commit message returned by the ``write`` method of ``DataSourceWriter``, which
+    will be sent back to the driver as an input parameter of the ``commit`` or
+    ``abort`` method.
+    """
+
+    ...
+
+
 @since(4.0)
 class DataSourceRegistration:
     """
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 6584312dda4..b429d73fb7d 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -33,6 +33,8 @@ class BasePythonDataSourceTestsMixin:
             ds.schema()
         with self.assertRaises(NotImplementedError):
             ds.reader(None)
+        with self.assertRaises(NotImplementedError):
+            ds.writer(None, None)
 
     def test_basic_data_source_reader_class(self):
         class MyDataSourceReader(DataSourceReader):

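As a rough mental model, the write/commit/abort protocol described in the new
docstrings amounts to the driver-side flow sketched below. This is a hypothetical
illustration, not Spark's actual orchestration code, which lives in the internal
write path rather than in `datasource.py`:

```python
from typing import Iterable, Iterator, List

from pyspark.sql import Row
from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage

def run_write_job(
    writer: DataSourceWriter, partitions: Iterable[Iterator[Row]]
) -> None:
    """Hypothetical sketch of the commit protocol, not Spark's real code."""
    messages: List[WriterCommitMessage] = []
    try:
        for partition in partitions:
            # In Spark, each of these calls runs as a separate task on an executor.
            messages.append(writer.write(partition))
    except Exception:
        # Any task failure aborts the job with the messages collected so far.
        writer.abort(messages)
        raise
    # All tasks succeeded: commit exactly once on the driver.
    writer.commit(messages)
```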

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
