This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 731b89d5914 [SPARK-41833][SPARK-41881][SPARK-41815][CONNECT][PYTHON] Make `DataFrame.collect` handle None/NaN/Array/Binary properly
731b89d5914 is described below

commit 731b89d59143adb8a4ab3d16dd9f0e08c799abf2
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Jan 5 08:52:08 2023 +0900

    [SPARK-41833][SPARK-41881][SPARK-41815][CONNECT][PYTHON] Make `DataFrame.collect` handle None/NaN/Array/Binary properly
    
    ### What changes were proposed in this pull request?
    
    The existing `DataFrame.collect` collects the incoming Arrow batches directly into a Pandas DataFrame and then converts each series into a `Row`, which is problematic since it cannot correctly handle None/NaN/Array/Binary values.
    
    This PR refactors `DataFrame.collect` to build rows directly from the raw Arrow Table (a minimal sketch follows the list), in order to support:
    1. None/NaN values;
    2. ArrayType;
    3. BinaryType.
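
    For reference, a rough standalone sketch (not part of this commit) of that approach: pyarrow's `Table.to_pylist` preserves `None` values and Python lists, and binary values are mapped from `bytes` to `bytearray` to match PySpark's `BinaryType` semantics. The sample table below is illustrative only.

    ```python
    # Illustrative sketch: build pyspark Rows directly from a raw pyarrow Table.
    import pyarrow as pa
    from pyspark.sql import Row

    table = pa.table({"a": [1, None], "arr": [[1, 2], None], "bin": [b"ab", None]})

    rows = []
    for record in table.to_pylist():  # None stays None; list values stay Python lists
        rows.append(
            Row(**{k: bytearray(v) if isinstance(v, bytes) else v for k, v in record.items()})
        )
    # rows == [Row(a=1, arr=[1, 2], bin=bytearray(b'ab')), Row(a=None, arr=None, bin=None)]
    ```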
    
    ### Why are the changes needed?
    To be consistent with PySpark
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    ### How was this patch tested?
    Enabled the previously skipped doctests.
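
    As an illustration (not taken from the patch), the re-enabled doctests exercise `collect()` against standard PySpark semantics for null, array, and binary values, roughly of this shape:

    ```python
    # Illustrative only: the standard PySpark behavior the Connect client must now match.
    from pyspark.sql import SparkSession, Row
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(1).select(
        F.lit(None).alias("n"),                    # null value
        F.array(F.lit(1), F.lit(2)).alias("arr"),  # ArrayType
        F.lit(b"ab").alias("bin"),                 # BinaryType
    )
    assert df.collect() == [Row(n=None, arr=[1, 2], bin=bytearray(b"ab"))]
    ```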
    
    Closes #39386 from zhengruifeng/connect_fix_41833.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/client.py    | 54 ++++++++++++++++++---------------
 python/pyspark/sql/connect/column.py    |  2 --
 python/pyspark/sql/connect/dataframe.py | 22 +++++++++++---
 python/pyspark/sql/connect/functions.py | 31 +++++--------------
 4 files changed, 55 insertions(+), 54 deletions(-)

diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index e78c4de0f70..832b5648676 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -21,12 +21,13 @@ import urllib.parse
 import uuid
 from typing import Iterable, Optional, Any, Union, List, Tuple, Dict, NoReturn, cast
 
+import pandas as pd
+import pyarrow as pa
+
 import google.protobuf.message
 from grpc_status import rpc_status
 import grpc
-import pandas
 from google.protobuf import text_format
-import pyarrow as pa
 from google.rpc import error_details_pb2
 
 import pyspark.sql.connect.proto as pb2
@@ -406,11 +407,22 @@ class SparkConnectClient(object):
             for x in metrics.metrics
         ]
 
-    def to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
+    def to_table(self, plan: pb2.Plan) -> "pa.Table":
+        logger.info(f"Executing plan {self._proto_to_string(plan)}")
+        req = self._execute_plan_request_with_metadata()
+        req.plan.CopyFrom(plan)
+        table, _ = self._execute_and_fetch(req)
+        return table
+
+    def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        return self._execute_and_fetch(req)
+        table, metrics = self._execute_and_fetch(req)
+        pdf = table.to_pandas()
+        if len(metrics) > 0:
+            pdf.attrs["metrics"] = metrics
+        return pdf
 
     def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
         return types.proto_schema_to_pyspark_data_type(schema)
@@ -521,10 +533,6 @@ class SparkConnectClient(object):
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
 
-    def _process_batch(self, arrow_batch: pb2.ExecutePlanResponse.ArrowBatch) -> "pandas.DataFrame":
-        with pa.ipc.open_stream(arrow_batch.data) as rd:
-            return rd.read_pandas()
-
     def _execute(self, req: pb2.ExecutePlanRequest) -> None:
         """
         Execute the passed request `req` and drop all results.
@@ -546,12 +554,14 @@ class SparkConnectClient(object):
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
 
-    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> "pandas.DataFrame":
+    def _execute_and_fetch(
+        self, req: pb2.ExecutePlanRequest
+    ) -> Tuple["pa.Table", List[PlanMetrics]]:
         logger.info("ExecuteAndFetch")
-        import pandas as pd
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-        result_dfs = []
+
+        batches: List[pa.RecordBatch] = []
 
         try:
             for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
@@ -567,25 +577,21 @@ class SparkConnectClient(object):
                         f"Received arrow batch rows={b.arrow_batch.row_count} "
                         f"size={len(b.arrow_batch.data)}"
                     )
-                    pb = self._process_batch(b.arrow_batch)
-                    result_dfs.append(pb)
+
+                    with pa.ipc.open_stream(b.arrow_batch.data) as reader:
+                        for batch in reader:
+                            assert isinstance(batch, pa.RecordBatch)
+                            batches.append(batch)
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
 
-        assert len(result_dfs) > 0
+        assert len(batches) > 0
 
-        df = pd.concat(result_dfs)
+        table = pa.Table.from_batches(batches=batches)
 
-        # pd.concat generates non-consecutive index like:
-        #   Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
-        # set it to RangeIndex to be consistent with pyspark
-        n = len(df)
-        df.set_index(pd.RangeIndex(start=0, stop=n, step=1), inplace=True)
+        metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
 
-        # Attach the metrics to the DataFrame attributes.
-        if m is not None:
-            df.attrs["metrics"] = self._build_metrics(m)
-        return df
+        return table, metrics
 
     def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
         """
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 6fda15e084a..4d0b3de322d 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -448,8 +448,6 @@ def _test() -> None:
         del pyspark.sql.connect.column.Column.dropFields.__doc__
         # TODO(SPARK-41772): Enable pyspark.sql.connect.column.Column.withField doctest
         del pyspark.sql.connect.column.Column.withField.__doc__
-        # TODO(SPARK-41815): Column.isNull returns nan instead of None
-        del pyspark.sql.connect.column.Column.isNull.__doc__
         # TODO(SPARK-41746): SparkSession.createDataFrame does not support nested datatypes
         del pyspark.sql.connect.column.Column.getField.__doc__
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b9d613870ab..fdb75d377b7 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1016,11 +1016,23 @@ class DataFrame:
         return ""
 
     def collect(self) -> List[Row]:
-        pdf = self.toPandas()
-        if pdf is not None:
-            return list(pdf.apply(lambda row: Row(**row), axis=1))
-        else:
-            return []
+        if self._plan is None:
+            raise Exception("Cannot collect on empty plan.")
+        if self._session is None:
+            raise Exception("Cannot collect on empty session.")
+        query = self._plan.to_proto(self._session.client)
+        table = self._session.client.to_table(query)
+
+        rows: List[Row] = []
+        for row in table.to_pylist():
+            _dict = {}
+            for k, v in row.items():
+                if isinstance(v, bytes):
+                    _dict[k] = bytearray(v)
+                else:
+                    _dict[k] = v
+            rows.append(Row(**_dict))
+        return rows
 
     collect.__doc__ = PySparkDataFrame.collect.__doc__
 
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 77c7db2d808..965a9a5331e 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -1872,7 +1872,7 @@ translate.__doc__ = pysparkfuncs.translate.__doc__
 
 
 # Date/Timestamp functions
-# TODO(SPARK-41283): Resolve dtypes inconsistencies for:
+# TODO(SPARK-41455): Resolve dtypes inconsistencies for:
 #     to_timestamp, from_utc_timestamp, to_utc_timestamp,
 #     timestamp_seconds, current_timestamp, date_trunc
 
@@ -2347,33 +2347,18 @@ def _test() -> None:
         # Spark Connect does not support Spark Context but the test depends on that.
         del pyspark.sql.connect.functions.monotonically_increasing_id.__doc__
 
-        # TODO(SPARK-41833): fix collect() output
-        del pyspark.sql.connect.functions.array.__doc__
-        del pyspark.sql.connect.functions.array_distinct.__doc__
-        del pyspark.sql.connect.functions.array_except.__doc__
-        del pyspark.sql.connect.functions.array_intersect.__doc__
-        del pyspark.sql.connect.functions.array_remove.__doc__
-        del pyspark.sql.connect.functions.array_repeat.__doc__
-        del pyspark.sql.connect.functions.array_sort.__doc__
-        del pyspark.sql.connect.functions.array_union.__doc__
-        del pyspark.sql.connect.functions.collect_list.__doc__
-        del pyspark.sql.connect.functions.collect_set.__doc__
-        del pyspark.sql.connect.functions.concat.__doc__
+        # TODO(SPARK-41880): Function `from_json` should support non-literal expression
+        # TODO(SPARK-41879): `DataFrame.collect` should support nested types
+        del pyspark.sql.connect.functions.struct.__doc__
         del pyspark.sql.connect.functions.create_map.__doc__
-        del pyspark.sql.connect.functions.date_trunc.__doc__
-        del pyspark.sql.connect.functions.from_utc_timestamp.__doc__
         del pyspark.sql.connect.functions.from_csv.__doc__
         del pyspark.sql.connect.functions.from_json.__doc__
-        del pyspark.sql.connect.functions.isnull.__doc__
-        del pyspark.sql.connect.functions.reverse.__doc__
-        del pyspark.sql.connect.functions.sequence.__doc__
-        del pyspark.sql.connect.functions.slice.__doc__
-        del pyspark.sql.connect.functions.sort_array.__doc__
-        del pyspark.sql.connect.functions.split.__doc__
-        del pyspark.sql.connect.functions.struct.__doc__
+
+        # TODO(SPARK-41455): Resolve dtypes inconsistencies of date/timestamp functions
         del pyspark.sql.connect.functions.to_timestamp.__doc__
         del pyspark.sql.connect.functions.to_utc_timestamp.__doc__
-        del pyspark.sql.connect.functions.unhex.__doc__
+        del pyspark.sql.connect.functions.date_trunc.__doc__
+        del pyspark.sql.connect.functions.from_utc_timestamp.__doc__
 
         # TODO(SPARK-41825): Dataframe.show formatting int as double
         del pyspark.sql.connect.functions.coalesce.__doc__

