This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 731b89d5914 [SPARK-41833][SPARK-41881][SPARK-41815][CONNECT][PYTHON] Make `DataFrame.collect` handle None/NaN/Array/Binary properly
731b89d5914 is described below

commit 731b89d59143adb8a4ab3d16dd9f0e08c799abf2
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Jan 5 08:52:08 2023 +0900

    [SPARK-41833][SPARK-41881][SPARK-41815][CONNECT][PYTHON] Make `DataFrame.collect` handle None/NaN/Array/Binary properly

    ### What changes were proposed in this pull request?
    The existing `DataFrame.collect` collects the incoming Arrow batches directly into a Pandas DataFrame and then converts each series into a Row. This is problematic because it cannot correctly handle None/NaN values, arrays, binary data, etc.

    This PR refactors `DataFrame.collect` to build rows directly from the raw Arrow Table (a sketch of this approach follows the diffstat below), in order to support:
    1. None/NaN values;
    2. ArrayType;
    3. BinaryType.

    ### Why are the changes needed?
    To be consistent with PySpark.

    ### Does this PR introduce _any_ user-facing change?
    Yes.

    ### How was this patch tested?
    Enabled doctests.

    Closes #39386 from zhengruifeng/connect_fix_41833.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/client.py    | 54 ++++++++++++++++++---------------
 python/pyspark/sql/connect/column.py    |  2 --
 python/pyspark/sql/connect/dataframe.py | 22 +++++++++++---
 python/pyspark/sql/connect/functions.py | 31 +++++--------------
 4 files changed, 55 insertions(+), 54 deletions(-)
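[Editor's note] A minimal sketch of the row-building approach described above, not the committed code: rows are built from a raw `pyarrow.Table` via `to_pylist()`, with `bytes` values mapped to `bytearray` the way PySpark exposes `BinaryType`. The table here is constructed locally purely for illustration; no Spark Connect server is involved.

```python
# Illustration only: build pyspark Rows straight from a pyarrow Table,
# so that nulls, arrays and binary values survive unchanged.
import pyarrow as pa
from pyspark.sql import Row

# Locally-built table standing in for the Arrow data a server would return.
table = pa.table(
    {
        "id": [1, 2, None],                   # null stays None instead of becoming NaN
        "arr": [[1, 2], [], None],            # ArrayType column
        "bin": [b"\x01\x02", None, b"\xff"],  # BinaryType column
    }
)

rows = []
for record in table.to_pylist():
    # PySpark represents BinaryType values as bytearray, so convert bytes accordingly.
    cleaned = {k: bytearray(v) if isinstance(v, bytes) else v for k, v in record.items()}
    rows.append(Row(**cleaned))

print(rows[0])  # Row(id=1, arr=[1, 2], bin=bytearray(b'\x01\x02'))
print(rows[2])  # Row(id=None, arr=None, bin=bytearray(b'\xff'))
```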
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index e78c4de0f70..832b5648676 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -21,12 +21,13 @@ import urllib.parse
 import uuid
 from typing import Iterable, Optional, Any, Union, List, Tuple, Dict, NoReturn, cast
 
+import pandas as pd
+import pyarrow as pa
+
 import google.protobuf.message
 from grpc_status import rpc_status
 import grpc
-import pandas
 from google.protobuf import text_format
-import pyarrow as pa
 from google.rpc import error_details_pb2
 
 import pyspark.sql.connect.proto as pb2
@@ -406,11 +407,22 @@ class SparkConnectClient(object):
             for x in metrics.metrics
         ]
 
-    def to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
+    def to_table(self, plan: pb2.Plan) -> "pa.Table":
+        logger.info(f"Executing plan {self._proto_to_string(plan)}")
+        req = self._execute_plan_request_with_metadata()
+        req.plan.CopyFrom(plan)
+        table, _ = self._execute_and_fetch(req)
+        return table
+
+    def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        return self._execute_and_fetch(req)
+        table, metrics = self._execute_and_fetch(req)
+        pdf = table.to_pandas()
+        if len(metrics) > 0:
+            pdf.attrs["metrics"] = metrics
+        return pdf
 
     def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
         return types.proto_schema_to_pyspark_data_type(schema)
@@ -521,10 +533,6 @@ class SparkConnectClient(object):
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
 
-    def _process_batch(self, arrow_batch: pb2.ExecutePlanResponse.ArrowBatch) -> "pandas.DataFrame":
-        with pa.ipc.open_stream(arrow_batch.data) as rd:
-            return rd.read_pandas()
-
     def _execute(self, req: pb2.ExecutePlanRequest) -> None:
         """
         Execute the passed request `req` and drop all results.
@@ -546,12 +554,14 @@ class SparkConnectClient(object):
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
 
-    def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> "pandas.DataFrame":
+    def _execute_and_fetch(
+        self, req: pb2.ExecutePlanRequest
+    ) -> Tuple["pa.Table", List[PlanMetrics]]:
         logger.info("ExecuteAndFetch")
-        import pandas as pd
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-        result_dfs = []
+
+        batches: List[pa.RecordBatch] = []
 
         try:
             for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
@@ -567,25 +577,21 @@ class SparkConnectClient(object):
                         f"Received arrow batch rows={b.arrow_batch.row_count} "
                         f"size={len(b.arrow_batch.data)}"
                     )
-                    pb = self._process_batch(b.arrow_batch)
-                    result_dfs.append(pb)
+
+                    with pa.ipc.open_stream(b.arrow_batch.data) as reader:
+                        for batch in reader:
+                            assert isinstance(batch, pa.RecordBatch)
+                            batches.append(batch)
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
 
-        assert len(result_dfs) > 0
+        assert len(batches) > 0
 
-        df = pd.concat(result_dfs)
+        table = pa.Table.from_batches(batches=batches)
 
-        # pd.concat generates non-consecutive index like:
-        #   Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
-        # set it to RangeIndex to be consistent with pyspark
-        n = len(df)
-        df.set_index(pd.RangeIndex(start=0, stop=n, step=1), inplace=True)
+        metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
 
-        # Attach the metrics to the DataFrame attributes.
-        if m is not None:
-            df.attrs["metrics"] = self._build_metrics(m)
-        return df
+        return table, metrics
 
     def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
         """
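[Editor's note] For context, a standalone illustration of the batch-assembly pattern the new `_execute_and_fetch` uses: each response payload is an Arrow IPC stream, the record batches are accumulated, and a single `pyarrow.Table` is built at the end. The helper and sample payloads below are made up for the example; in the real client the bytes come from `b.arrow_batch.data` in the gRPC responses.

```python
# Illustration only: merge several Arrow IPC stream payloads into one Table,
# mirroring the accumulation done in _execute_and_fetch above.
import pyarrow as pa

schema = pa.schema([("id", pa.int64()), ("name", pa.string())])

def make_ipc_payload(ids, names):
    # Serialize one RecordBatch as an IPC stream; stands in for b.arrow_batch.data.
    batch = pa.record_batch([pa.array(ids, pa.int64()), pa.array(names, pa.string())], schema=schema)
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, schema) as writer:
        writer.write_batch(batch)
    return sink.getvalue()

payloads = [make_ipc_payload([1, 2], ["a", "b"]), make_ipc_payload([3], ["c"])]

batches = []
for data in payloads:
    with pa.ipc.open_stream(data) as reader:
        for batch in reader:  # each stream can carry multiple record batches
            batches.append(batch)

table = pa.Table.from_batches(batches=batches)
print(table.num_rows)  # 3
```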
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 6fda15e084a..4d0b3de322d 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -448,8 +448,6 @@ def _test() -> None:
     del pyspark.sql.connect.column.Column.dropFields.__doc__
     # TODO(SPARK-41772): Enable pyspark.sql.connect.column.Column.withField doctest
     del pyspark.sql.connect.column.Column.withField.__doc__
-    # TODO(SPARK-41815): Column.isNull returns nan instead of None
-    del pyspark.sql.connect.column.Column.isNull.__doc__
     # TODO(SPARK-41746): SparkSession.createDataFrame does not support nested datatypes
     del pyspark.sql.connect.column.Column.getField.__doc__
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b9d613870ab..fdb75d377b7 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1016,11 +1016,23 @@ class DataFrame:
         return ""
 
     def collect(self) -> List[Row]:
-        pdf = self.toPandas()
-        if pdf is not None:
-            return list(pdf.apply(lambda row: Row(**row), axis=1))
-        else:
-            return []
+        if self._plan is None:
+            raise Exception("Cannot collect on empty plan.")
+        if self._session is None:
+            raise Exception("Cannot collect on empty session.")
+        query = self._plan.to_proto(self._session.client)
+        table = self._session.client.to_table(query)
+
+        rows: List[Row] = []
+        for row in table.to_pylist():
+            _dict = {}
+            for k, v in row.items():
+                if isinstance(v, bytes):
+                    _dict[k] = bytearray(v)
+                else:
+                    _dict[k] = v
+            rows.append(Row(**_dict))
+        return rows
 
     collect.__doc__ = PySparkDataFrame.collect.__doc__
 
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 77c7db2d808..965a9a5331e 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -1872,7 +1872,7 @@ translate.__doc__ = pysparkfuncs.translate.__doc__
 
 # Date/Timestamp functions
 
-# TODO(SPARK-41283): Resolve dtypes inconsistencies for:
+# TODO(SPARK-41455): Resolve dtypes inconsistencies for:
 #   to_timestamp, from_utc_timestamp, to_utc_timestamp,
 #   timestamp_seconds, current_timestamp, date_trunc
 
@@ -2347,33 +2347,18 @@ def _test() -> None:
     # Spark Connect does not support Spark Context but the test depends on that.
     del pyspark.sql.connect.functions.monotonically_increasing_id.__doc__
 
-    # TODO(SPARK-41833): fix collect() output
-    del pyspark.sql.connect.functions.array.__doc__
-    del pyspark.sql.connect.functions.array_distinct.__doc__
-    del pyspark.sql.connect.functions.array_except.__doc__
-    del pyspark.sql.connect.functions.array_intersect.__doc__
-    del pyspark.sql.connect.functions.array_remove.__doc__
-    del pyspark.sql.connect.functions.array_repeat.__doc__
-    del pyspark.sql.connect.functions.array_sort.__doc__
-    del pyspark.sql.connect.functions.array_union.__doc__
-    del pyspark.sql.connect.functions.collect_list.__doc__
-    del pyspark.sql.connect.functions.collect_set.__doc__
-    del pyspark.sql.connect.functions.concat.__doc__
+    # TODO(SPARK-41880): Function `from_json` should support non-literal expression
+    # TODO(SPARK-41879): `DataFrame.collect` should support nested types
+    del pyspark.sql.connect.functions.struct.__doc__
     del pyspark.sql.connect.functions.create_map.__doc__
-    del pyspark.sql.connect.functions.date_trunc.__doc__
-    del pyspark.sql.connect.functions.from_utc_timestamp.__doc__
     del pyspark.sql.connect.functions.from_csv.__doc__
     del pyspark.sql.connect.functions.from_json.__doc__
-    del pyspark.sql.connect.functions.isnull.__doc__
-    del pyspark.sql.connect.functions.reverse.__doc__
-    del pyspark.sql.connect.functions.sequence.__doc__
-    del pyspark.sql.connect.functions.slice.__doc__
-    del pyspark.sql.connect.functions.sort_array.__doc__
-    del pyspark.sql.connect.functions.split.__doc__
-    del pyspark.sql.connect.functions.struct.__doc__
+
+    # TODO(SPARK-41455): Resolve dtypes inconsistencies of date/timestamp functions
     del pyspark.sql.connect.functions.to_timestamp.__doc__
     del pyspark.sql.connect.functions.to_utc_timestamp.__doc__
-    del pyspark.sql.connect.functions.unhex.__doc__
+    del pyspark.sql.connect.functions.date_trunc.__doc__
+    del pyspark.sql.connect.functions.from_utc_timestamp.__doc__
 
     # TODO(SPARK-41825): Dataframe.show formatting int as double
     del pyspark.sql.connect.functions.coalesce.__doc__

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org