This is an automated email from the ASF dual-hosted git repository. ueshin pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 803b2854a9e [SPARK-44479][PYTHON][3.5] Fix ArrowStreamPandasUDFSerializer to accept no-column pandas DataFrame
803b2854a9e is described below

commit 803b2854a9e82aee4e5691c4a9a697856b963377
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Jul 26 17:54:38 2023 -0700

    [SPARK-44479][PYTHON][3.5] Fix ArrowStreamPandasUDFSerializer to accept no-column pandas DataFrame

    ### What changes were proposed in this pull request?

    Fixes `ArrowStreamPandasUDFSerializer` to accept no-column pandas DataFrame.

    ```py
    >>> def _scalar_f(id):
    ...     return pd.DataFrame(index=id)
    ...
    >>> scalar_f = pandas_udf(_scalar_f, returnType=StructType())
    >>> df = spark.range(3).withColumn("f", scalar_f(col("id")))
    >>> df.printSchema()
    root
     |-- id: long (nullable = false)
     |-- f: struct (nullable = true)

    >>> df.show()
    +---+---+
    | id|  f|
    +---+---+
    |  0| {}|
    |  1| {}|
    |  2| {}|
    +---+---+
    ```

    ### Why are the changes needed?

    The above query fails with the following error:

    ```py
    >>> df.show()
    org.apache.spark.api.python.PythonException: Traceback (most recent call last):
    ...
    ValueError: not enough values to unpack (expected 2, got 0)
    ```

    ### Does this PR introduce _any_ user-facing change?

    Yes, Pandas UDF will accept no-column pandas DataFrame.

    ### How was this patch tested?

    Added related tests.

    Closes #42176 from ueshin/issues/SPARK-44479/3.5/empty_schema.
Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- python/pyspark/sql/connect/types.py | 4 ++- python/pyspark/sql/pandas/serializers.py | 31 ++++++++-------------- .../sql/tests/pandas/test_pandas_udf_scalar.py | 23 +++++++++++++++- python/pyspark/sql/tests/test_udtf.py | 11 ++++++-- 4 files changed, 45 insertions(+), 24 deletions(-) diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py index 2a21cdf0675..0db2833d2c1 100644 --- a/python/pyspark/sql/connect/types.py +++ b/python/pyspark/sql/connect/types.py @@ -170,6 +170,7 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType: ret.year_month_interval.start_field = data_type.startField ret.year_month_interval.end_field = data_type.endField elif isinstance(data_type, StructType): + struct = pb2.DataType.Struct() for field in data_type.fields: struct_field = pb2.DataType.StructField() struct_field.name = field.name @@ -177,7 +178,8 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType: struct_field.nullable = field.nullable if field.metadata is not None and len(field.metadata) > 0: struct_field.metadata = json.dumps(field.metadata) - ret.struct.fields.append(struct_field) + struct.fields.append(struct_field) + ret.struct.CopyFrom(struct) elif isinstance(data_type, MapType): ret.map.key_type.CopyFrom(pyspark_types_to_proto_types(data_type.keyType)) ret.map.value_type.CopyFrom(pyspark_types_to_proto_types(data_type.valueType)) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index f22a73cbbef..1d326928e23 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -385,37 +385,28 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): """ import pyarrow as pa - # Input partition and result pandas.DataFrame empty, make empty Arrays with struct - if len(df) == 0 and len(df.columns) == 0: - 
arrs_names = [ - (pa.array([], type=field.type), field.name) for field in arrow_struct_type - ] + if len(df.columns) == 0: + return pa.array([{}] * len(df), arrow_struct_type) # Assign result columns by schema name if user labeled with strings - elif self._assign_cols_by_name and any(isinstance(name, str) for name in df.columns): - arrs_names = [ - ( - self._create_array(df[field.name], field.type, arrow_cast=self._arrow_cast), - field.name, - ) + if self._assign_cols_by_name and any(isinstance(name, str) for name in df.columns): + struct_arrs = [ + self._create_array(df[field.name], field.type, arrow_cast=self._arrow_cast) for field in arrow_struct_type ] # Assign result columns by position else: - arrs_names = [ + struct_arrs = [ # the selected series has name '1', so we rename it to field.name # as the name is used by _create_array to provide a meaningful error message - ( - self._create_array( - df[df.columns[i]].rename(field.name), - field.type, - arrow_cast=self._arrow_cast, - ), - field.name, + self._create_array( + df[df.columns[i]].rename(field.name), + field.type, + arrow_cast=self._arrow_cast, ) for i, field in enumerate(arrow_struct_type) ] - struct_arrs, struct_names = zip(*arrs_names) + struct_names = [field.name for field in arrow_struct_type] return pa.StructArray.from_arrays(struct_arrs, struct_names) def _create_batch(self, series): diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py index fef21224266..7a80547b3fc 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py @@ -60,7 +60,7 @@ from pyspark.testing.sqlutils import ( pandas_requirement_message, pyarrow_requirement_message, ) -from pyspark.testing.utils import QuietTest +from pyspark.testing.utils import QuietTest, assertDataFrameEqual if have_pandas: import pandas as pd @@ -528,6 +528,27 @@ class ScalarPandasUDFTestsMixin: 
self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0]) self.assertListEqual([i, i + 1], f[1]) + def test_vectorized_udf_struct_empty(self): + df = self.spark.range(3) + return_type = StructType() + + def _scalar_f(id): + return pd.DataFrame(index=id) + + scalar_f = pandas_udf(_scalar_f, returnType=return_type) + + @pandas_udf(returnType=return_type, functionType=PandasUDFType.SCALAR_ITER) + def iter_f(it): + for id in it: + yield _scalar_f(id) + + for f, udf_type in [(scalar_f, "SCALAR"), (iter_f, "SCALAR_ITER")]: + with self.subTest(udf_type=udf_type): + assertDataFrameEqual( + df.withColumn("f", f(col("id"))), + [Row(id=0, f=Row()), Row(id=1, f=Row()), Row(id=2, f=Row())], + ) + def test_vectorized_udf_nested_struct(self): with QuietTest(self.sc): self.check_vectorized_udf_nested_struct() diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index bb84b9d836a..0aa769e506d 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -29,9 +29,8 @@ from pyspark.errors import ( ) from pyspark.rdd import PythonEvalType from pyspark.sql.functions import lit, udf, udtf -from pyspark.sql.types import Row +from pyspark.sql.types import IntegerType, MapType, Row, StringType, StructType from pyspark.testing import assertDataFrameEqual -from pyspark.sql.types import MapType, StringType, IntegerType from pyspark.testing.sqlutils import ( have_pandas, have_pyarrow, @@ -466,6 +465,14 @@ class BaseUDTFTestsMixin: self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x={1: "1"})]) + def test_udtf_with_empty_output_types(self): + @udtf(returnType=StructType()) + class TestUDTF: + def eval(self): + yield tuple() + + assertDataFrameEqual(TestUDTF(), [Row()]) + @unittest.skipIf(not have_pandas, pandas_requirement_message) def test_udtf_with_pandas_input_type(self): import pandas as pd --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org 
For additional commands, e-mail: commits-h...@spark.apache.org