This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ecca1bf6453e [SPARK-47365][PYTHON] Add toArrow() DataFrame method to PySpark
ecca1bf6453e is described below

commit ecca1bf6453e5e0042e1b56d4c35fb0b4d0f3121
Author: Ian Cook <ianmc...@gmail.com>
AuthorDate: Thu May 9 17:25:34 2024 +0900

    [SPARK-47365][PYTHON] Add toArrow() DataFrame method to PySpark
    
    ### What changes were proposed in this pull request?
    - Add a PySpark DataFrame method `toArrow()` which returns the contents of the DataFrame as a [PyArrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html), for both local Spark and Spark Connect.
    - Add a new entry to the **Apache Arrow in PySpark** user guide page describing usage of the `toArrow()` method.
    - Add a new option to the method `_collect_as_arrow()` to provide more useful output when there are zero records returned. (This keeps the implementation of `toArrow()` simpler.)
    
    ### Why are the changes needed?
    In the Apache Arrow community, we hear from a lot of users who want to return the contents of a PySpark DataFrame as a PyArrow Table. Currently the only documented way to do this is to return the contents as a pandas DataFrame, then use PyArrow (`pa`) to convert that to a PyArrow Table.
    ```py
    pa.Table.from_pandas(df.toPandas())
    ```
    But going through pandas adds significant overhead, which is easily avoided because `toPandas()` already converts the contents of the Spark DataFrame to Arrow format internally as an intermediate step when `spark.sql.execution.arrow.pyspark.enabled` is `true`.
    
    Currently it is also possible to use the experimental `_collect_as_arrow()` method to return the contents of a PySpark DataFrame as a list of PyArrow RecordBatches. This PR adds a new non-experimental method `toArrow()` which returns the more user-friendly PyArrow Table object.
    
    This PR also adds a new argument `empty_list_if_zero_records` to the experimental method `_collect_as_arrow()` to control what the method returns in the case when the result data has zero rows. If set to `True` (the default), the existing behavior is preserved, and the method returns an empty Python list. If set to `False`, the method returns a length-one list containing an empty Arrow RecordBatch which includes the schema. This is used by `toArrow()` which requires the schema [...]
    
    For Spark Connect, there is already a `SparkSession.client.to_table()` method that returns a PyArrow table. This PR uses that to expose `toArrow()` for Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    - It adds a DataFrame method `toArrow()` to the PySpark SQL DataFrame API.
    - It adds a new argument `empty_list_if_zero_records` to the experimental DataFrame method `_collect_as_arrow()` with a default value which preserves the method's existing behavior.
    - It exposes `toArrow()` for Spark Connect, via the existing `SparkSession.client.to_table()` method.
    - It does not introduce any other user-facing changes.
    
    ### How was this patch tested?
    This adds a new test and a new helper function for the test in `pyspark/sql/tests/test_arrow.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45481 from ianmcook/SPARK-47365.
    
    Lead-authored-by: Ian Cook <ianmc...@gmail.com>
    Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 examples/src/main/python/sql/arrow.py              | 18 ++++++++
 .../source/reference/pyspark.sql/dataframe.rst     |  1 +
 python/docs/source/user_guide/sql/arrow_pandas.rst | 49 +++++++++++++++-------
 python/pyspark/sql/classic/dataframe.py            |  4 ++
 python/pyspark/sql/connect/dataframe.py            |  4 ++
 python/pyspark/sql/dataframe.py                    | 30 +++++++++++++
 python/pyspark/sql/pandas/conversion.py            | 48 +++++++++++++++++++--
 python/pyspark/sql/tests/test_arrow.py             | 35 ++++++++++++++++
 8 files changed, 169 insertions(+), 20 deletions(-)

diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py
index 03daf18eadbf..48aee48d929c 100644
--- a/examples/src/main/python/sql/arrow.py
+++ b/examples/src/main/python/sql/arrow.py
@@ -33,6 +33,22 @@ require_minimum_pandas_version()
 require_minimum_pyarrow_version()
 
 
+def dataframe_to_arrow_table_example(spark: SparkSession) -> None:
+    import pyarrow as pa  # noqa: F401
+    from pyspark.sql.functions import rand
+
+    # Create a Spark DataFrame
+    df = spark.range(100).drop("id").withColumns({"0": rand(), "1": rand(), "2": rand()})
+
+    # Convert the Spark DataFrame to a PyArrow Table
+    table = df.select("*").toArrow()
+
+    print(table.schema)
+    # 0: double not null
+    # 1: double not null
+    # 2: double not null
+
+
 def dataframe_with_arrow_example(spark: SparkSession) -> None:
     import numpy as np
     import pandas as pd
@@ -302,6 +318,8 @@ if __name__ == "__main__":
         .appName("Python Arrow-in-Spark example") \
         .getOrCreate()
 
+    print("Running Arrow conversion example: DataFrame to Table")
+    dataframe_to_arrow_table_example(spark)
     print("Running Pandas to/from conversion example")
     dataframe_with_arrow_example(spark)
     print("Running pandas_udf example: Series to Frame")
diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst
index b69a2771b04f..ec39b645b140 100644
--- a/python/docs/source/reference/pyspark.sql/dataframe.rst
+++ b/python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -109,6 +109,7 @@ DataFrame
     DataFrame.tail
     DataFrame.take
     DataFrame.to
+    DataFrame.toArrow
     DataFrame.toDF
     DataFrame.toJSON
     DataFrame.toLocalIterator
diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst b/python/docs/source/user_guide/sql/arrow_pandas.rst
index 1d6a4df60690..0a527d832e21 100644
--- a/python/docs/source/user_guide/sql/arrow_pandas.rst
+++ b/python/docs/source/user_guide/sql/arrow_pandas.rst
@@ -39,6 +39,20 @@ is installed and available on all cluster nodes.
 You can install it using pip or conda from the conda-forge channel. See PyArrow
 `installation <https://arrow.apache.org/docs/python/install.html>`_ for details.
 
+Conversion to Arrow Table
+-------------------------
+
+You can call :meth:`DataFrame.toArrow` to convert a Spark DataFrame to a PyArrow Table.
+
+.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
+    :language: python
+    :lines: 37-49
+    :dedent: 4
+
+Note that :meth:`DataFrame.toArrow` results in the collection of all records in the DataFrame to
+the driver program and should be done on a small subset of the data. Not all Spark data types are
+currently supported and an error can be raised if a column has an unsupported type.
+
 Enabling for Conversion to/from Pandas
 --------------------------------------
 
@@ -53,7 +67,7 @@ This can be controlled by ``spark.sql.execution.arrow.pyspark.fallback.enabled``
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 37-52
+    :lines: 53-68
     :dedent: 4
 
 Using the above optimizations with Arrow will produce the same results as when Arrow is not
@@ -90,7 +104,7 @@ specify the type hints of ``pandas.Series`` and ``pandas.DataFrame`` as below:
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 56-80
+    :lines: 72-96
     :dedent: 4
 
 In the following sections, it describes the combinations of the supported type hints. For simplicity,
@@ -113,7 +127,7 @@ The following example shows how to create this Pandas UDF that computes the prod
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 84-114
+    :lines: 100-130
     :dedent: 4
 
 For detailed usage, please see :func:`pandas_udf`.
@@ -152,7 +166,7 @@ The following example shows how to create this Pandas UDF:
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 118-140
+    :lines: 134-156
     :dedent: 4
 
 For detailed usage, please see :func:`pandas_udf`.
@@ -174,7 +188,7 @@ The following example shows how to create this Pandas UDF:
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 144-167
+    :lines: 160-183
     :dedent: 4
 
 For detailed usage, please see :func:`pandas_udf`.
@@ -205,7 +219,7 @@ and window operations:
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 171-212
+    :lines: 187-228
     :dedent: 4
 
 .. currentmodule:: pyspark.sql.functions
@@ -270,7 +284,7 @@ in the group.
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 216-234
+    :lines: 232-250
     :dedent: 4
 
 For detailed usage, please see  please see :meth:`GroupedData.applyInPandas`
@@ -288,7 +302,7 @@ The following example shows how to use :meth:`DataFrame.mapInPandas`:
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 238-249
+    :lines: 254-265
     :dedent: 4
 
 For detailed usage, please see :meth:`DataFrame.mapInPandas`.
@@ -327,7 +341,7 @@ The following example shows how to use ``DataFrame.groupby().cogroup().applyInPa
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 253-275
+    :lines: 269-291
     :dedent: 4
 
 
@@ -349,7 +363,7 @@ Here's an example that demonstrates the usage of both a default, pickled Python
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 279-297
+    :lines: 295-313
     :dedent: 4
 
 Compared to the default, pickled Python UDFs, Arrow Python UDFs provide a more coherent type coercion mechanism. UDF
@@ -421,9 +435,12 @@ be verified by the user.
 Setting Arrow ``self_destruct`` for memory savings
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled`` can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas DataFrame.
-This option is experimental, and some operations may fail on the resulting Pandas DataFrame due to immutable backing arrays.
-Typically, you would see the error ``ValueError: buffer source array is read-only``.
-Newer versions of Pandas may fix these errors by improving support for such cases.
-You can work around this error by copying the column(s) beforehand.
-Additionally, this conversion may be slower because it is single-threaded.
+Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled``
+can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a
+Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas
+DataFrame. This option can also save memory when creating a PyArrow Table via ``toArrow``.
+This option is experimental. When used with ``toPandas``, some operations may fail on the resulting
+Pandas DataFrame due to immutable backing arrays. Typically, you would see the error
+``ValueError: buffer source array is read-only``. Newer versions of Pandas may fix these errors by
+improving support for such cases. You can work around this error by copying the column(s)
+beforehand. Additionally, this conversion may be slower because it is single-threaded.
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index db9f22517dda..9b6790d29aaa 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -74,6 +74,7 @@ from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
+    import pyarrow as pa
     from pyspark.core.rdd import RDD
     from pyspark.core.context import SparkContext
     from pyspark._typing import PrimitiveType
@@ -1825,6 +1826,9 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
     ) -> ParentDataFrame:
         return PandasMapOpsMixin.mapInArrow(self, func, schema, barrier, profile)
 
+    def toArrow(self) -> "pa.Table":
+        return PandasConversionMixin.toArrow(self)
+
     def toPandas(self) -> "PandasDataFrameLike":
         return PandasConversionMixin.toPandas(self)
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 843c92a9b27d..3c9415adec2d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1768,6 +1768,10 @@ class DataFrame(ParentDataFrame):
         assert table is not None
         return (table, schema)
 
+    def toArrow(self) -> "pa.Table":
+        table, _ = self._to_table()
+        return table
+
     def toPandas(self) -> "PandasDataFrameLike":
         query = self._plan.to_proto(self._session.client)
         return self._session.client.to_pandas(query, self._plan.observations)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index e3d52c45d0c1..886f72cc371e 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -44,6 +44,7 @@ from pyspark.sql.utils import dispatch_df_method
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
+    import pyarrow as pa
     from pyspark.core.context import SparkContext
     from pyspark.core.rdd import RDD
     from pyspark._typing import PrimitiveType
@@ -1200,6 +1201,7 @@ class DataFrame:
         DataFrame.take : Returns the first `n` rows.
         DataFrame.head : Returns the first `n` rows.
         DataFrame.toPandas : Returns the data as a pandas DataFrame.
+        DataFrame.toArrow : Returns the data as a PyArrow Table.
 
         Notes
         -----
@@ -6213,6 +6215,34 @@ class DataFrame:
         """
         ...
 
+    @dispatch_df_method
+    def toArrow(self) -> "pa.Table":
+        """
+        Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``.
+
+        This is only available if PyArrow is installed and available.
+
+        .. versionadded:: 4.0.0
+
+        Notes
+        -----
+        This method should only be used if the resulting PyArrow ``pyarrow.Table`` is
+        expected to be small, as all the data is loaded into the driver's memory.
+
+        This API is a developer API.
+
+        Examples
+        --------
+        >>> df.toArrow()  # doctest: +SKIP
+        pyarrow.Table
+        age: int64
+        name: string
+        ----
+        age: [[2,5]]
+        name: [["Alice","Bob"]]
+        """
+        ...
+
     def toPandas(self) -> "PandasDataFrameLike":
         """
         Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index ec4e21daba97..344608317beb 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -225,15 +225,48 @@ class PandasConversionMixin:
         else:
             return pdf
 
-    def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch"]:
+    def toArrow(self) -> "pa.Table":
+        from pyspark.sql.dataframe import DataFrame
+
+        assert isinstance(self, DataFrame)
+
+        jconf = self.sparkSession._jconf
+
+        from pyspark.sql.pandas.types import to_arrow_schema
+        from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
+
+        require_minimum_pyarrow_version()
+        to_arrow_schema(self.schema)
+
+        import pyarrow as pa
+
+        self_destruct = jconf.arrowPySparkSelfDestructEnabled()
+        batches = self._collect_as_arrow(
+            split_batches=self_destruct, empty_list_if_zero_records=False
+        )
+        table = pa.Table.from_batches(batches)
+        # Ensure only the table has a reference to the batches, so that
+        # self_destruct (if enabled) is effective
+        del batches
+        return table
+
+    def _collect_as_arrow(
+        self,
+        split_batches: bool = False,
+        empty_list_if_zero_records: bool = True,
+    ) -> List["pa.RecordBatch"]:
         """
-        Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
+        Returns all records as a list of Arrow RecordBatches. PyArrow must be installed
         and available on driver and worker Python environments.
         This is an experimental feature.
 
         :param split_batches: split batches such that each column is in its own allocation, so
             that the selfDestruct optimization is effective; default False.
 
+        :param empty_list_if_zero_records: If True (the default), returns an empty list if the
+            result has 0 records. Otherwise, returns a list of length 1 containing an empty
+            Arrow RecordBatch which includes the schema.
+
         .. note:: Experimental.
         """
         from pyspark.sql.dataframe import DataFrame
@@ -282,8 +315,15 @@ class PandasConversionMixin:
         batches = results[:-1]
         batch_order = results[-1]
 
-        # Re-order the batch list using the correct order
-        return [batches[i] for i in batch_order]
+        if len(batches) or empty_list_if_zero_records:
+            # Re-order the batch list using the correct order
+            return [batches[i] for i in batch_order]
+        else:
+            from pyspark.sql.pandas.types import to_arrow_schema
+
+            schema = to_arrow_schema(self.schema)
+            empty_arrays = [pa.array([], type=field.type) for field in schema]
+            return [pa.RecordBatch.from_arrays(empty_arrays, schema=schema)]
 
 
 class SparkConversionMixin:
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 8636e953aaf8..71d3c46e5ee1 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -179,6 +179,35 @@ class ArrowTestsMixin:
         data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
         return pd.DataFrame(data=data_dict)
 
+    def create_arrow_table(self):
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        data_dict = {}
+        for j, name in enumerate(self.schema.names):
+            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+        t = pa.Table.from_pydict(data_dict)
+        # convert these to Arrow types
+        new_schema = t.schema.set(
+            t.schema.get_field_index("2_int_t"), pa.field("2_int_t", pa.int32())
+        )
+        new_schema = new_schema.set(
+            new_schema.get_field_index("4_float_t"), pa.field("4_float_t", pa.float32())
+        )
+        new_schema = new_schema.set(
+            new_schema.get_field_index("6_decimal_t"),
+            pa.field("6_decimal_t", pa.decimal128(38, 18)),
+        )
+        t = t.cast(new_schema)
+        # convert timestamp to local timezone
+        timezone = self.spark.conf.get("spark.sql.session.timeZone")
+        t = t.set_column(
+            t.schema.get_field_index("8_timestamp_t"),
+            "8_timestamp_t",
+            pc.assume_timezone(t["8_timestamp_t"], timezone),
+        )
+        return t
+
     @property
     def create_np_arrs(self):
         import numpy as np
@@ -339,6 +368,12 @@ class ArrowTestsMixin:
         pdf_arrow = df.toPandas()
         assert_frame_equal(pdf_arrow, pdf)
 
+    def test_arrow_round_trip(self):
+        t_in = self.create_arrow_table()
+        df = self.spark.createDataFrame(self.data, schema=self.schema)
+        t_out = df.toArrow()
+        self.assertTrue(t_out.equals(t_in))
+
     def test_pandas_self_destruct(self):
         import pyarrow as pa
 

