This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 14d347c6c85 [SPARK-45654][PYTHON] Add Python data source write API 14d347c6c85 is described below commit 14d347c6c85172cee904bd3131392839594d5f2f Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Wed Nov 1 09:14:23 2023 +0900 [SPARK-45654][PYTHON] Add Python data source write API ### What changes were proposed in this pull request? This PR adds Python data source write API and `DataSourceWriter` class in `datasource.py`. Here is an overview of the writer class: ```python class DataSourceWriter(ABC): @abstractmethod def write(self, iterator: Iterator[Row]) -> Any: ... def commit(self, messages: List[Any]) -> None: ... def abort(self, messages: List[Any]) -> None: ... ``` ### Why are the changes needed? To support Python data source write. ### Does this PR introduce _any_ user-facing change? No. This PR alone does not introduce any user-facing change. ### How was this patch tested? Unit test ### Was this patch authored or co-authored using generative AI tooling? No Closes #43516 from allisonwang-db/spark-45654-write-api. Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/datasource.py | 99 +++++++++++++++++++++- python/pyspark/sql/tests/test_python_datasource.py | 2 + 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py index 5cda6596b3f..c30a2c8689d 100644 --- a/python/pyspark/sql/datasource.py +++ b/python/pyspark/sql/datasource.py @@ -26,9 +26,10 @@ if TYPE_CHECKING: from pyspark.sql.session import SparkSession -__all__ = ["DataSource", "DataSourceReader", "DataSourceRegistration"] +__all__ = ["DataSource", "DataSourceReader", "DataSourceWriter", "DataSourceRegistration"] +@since(4.0) class DataSource(ABC): """ A base class for data sources. 
@@ -133,7 +134,29 @@ class DataSource(ABC): """ raise NotImplementedError + def writer(self, schema: StructType, saveMode: str) -> "DataSourceWriter": + """ + Returns a ``DataSourceWriter`` instance for writing data. + + The implementation is required for writable data sources. + + Parameters + ---------- + schema : StructType + The schema of the data to be written. + saveMode : str + A string identifies the save mode. It can be one of the following: + `append`, `overwrite`, `error`, `ignore`. + + Returns + ------- + writer : DataSourceWriter + A writer instance for this data source. + """ + raise NotImplementedError + +@since(4.0) class DataSourceReader(ABC): """ A base class for data source readers. Data source readers are responsible for @@ -229,6 +252,80 @@ class DataSourceReader(ABC): ... +@since(4.0) +class DataSourceWriter(ABC): + """ + A base class for data source writers. Data source writers are responsible for saving + the data to the data source. + """ + + @abstractmethod + def write(self, iterator: Iterator[Row]) -> "WriterCommitMessage": + """ + Writes data into the data source. + + This method is called once on each executor to write data to the data source. + It accepts an iterator of input data and returns a single row representing a + commit message, or None if there is no commit message. + + The driver collects commit messages, if any, from all executors and passes them + to the ``commit`` method if all tasks run successfully. If any task fails, the + ``abort`` method will be called with the collected commit messages. + + Parameters + ---------- + iterator : Iterator[Row] + An iterator of input data. + + Returns + ------- + WriterCommitMessage : a serializable commit message + """ + ... + + def commit(self, messages: List["WriterCommitMessage"]) -> None: + """ + Commits this writing job with a list of commit messages. + + This method is invoked on the driver when all tasks run successfully. 
The + commit messages are collected from the ``write`` method call from each task, + and are passed to this method. The implementation should use the commit messages + to commit the writing job to the data source. + + Parameters + ---------- + messages : List[WriterCommitMessage] + A list of commit messages. + """ + ... + + def abort(self, messages: List["WriterCommitMessage"]) -> None: + """ + Aborts this writing job due to task failures. + + This method is invoked on the driver when one or more tasks failed. The commit + messages are collected from the ``write`` method call from each task, and are + passed to this method. The implementation should use the commit messages to + abort the writing job to the data source. + + Parameters + ---------- + messages : List[WriterCommitMessage] + A list of commit messages. + """ + ... + + +@since(4.0) +class WriterCommitMessage: + """ + A commit message returned by the ``write`` method of ``DataSourceWriter`` and will be + sent back to the driver side as input parameter of ``commit`` or ``abort`` method. + """ + + ... + + @since(4.0) class DataSourceRegistration: """ diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index 6584312dda4..b429d73fb7d 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -33,6 +33,8 @@ class BasePythonDataSourceTestsMixin: ds.schema() with self.assertRaises(NotImplementedError): ds.reader(None) + with self.assertRaises(NotImplementedError): + ds.writer(None, None) def test_basic_data_source_reader_class(self): class MyDataSourceReader(DataSourceReader): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org