This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 02e36dd0f07 [SPARK-44479][PYTHON] Fix ArrowStreamPandasUDFSerializer to accept no-column pandas DataFrame
02e36dd0f07 is described below

commit 02e36dd0f077d11a75c6e083489dc1a51c870a0d
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Jul 26 17:53:46 2023 -0700

    [SPARK-44479][PYTHON] Fix ArrowStreamPandasUDFSerializer to accept no-column pandas DataFrame
    
    ### What changes were proposed in this pull request?
    
    Fixes `ArrowStreamPandasUDFSerializer` to accept a pandas DataFrame with no columns (e.g. the result of a UDF whose return type is an empty `StructType()`).
    
    ```py
    >>> def _scalar_f(id):
    ...   return pd.DataFrame(index=id)
    ...
    >>> scalar_f = pandas_udf(_scalar_f, returnType=StructType())
    >>> df = spark.range(3).withColumn("f", scalar_f(col("id")))
    >>> df.printSchema()
    root
     |-- id: long (nullable = false)
     |-- f: struct (nullable = true)
    
    >>> df.show()
    +---+---+
    | id|  f|
    +---+---+
    |  0| {}|
    |  1| {}|
    |  2| {}|
    +---+---+
    ```
    
    ### Why are the changes needed?
    
    The above query fails with the following error:
    
    ```py
    >>> df.show()
    org.apache.spark.api.python.PythonException: Traceback (most recent call last):
    ...
    ValueError: not enough values to unpack (expected 2, got 0)
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, pandas UDFs will now accept a pandas DataFrame with no columns.
    
    ### How was this patch tested?
    
    Added related tests.
    
    Closes #42161 from ueshin/issues/SPARK-44479/empty_schema.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/sql/connect/types.py                |  4 ++-
 python/pyspark/sql/pandas/serializers.py           | 31 ++++++++--------------
 .../sql/tests/pandas/test_pandas_udf_scalar.py     | 23 +++++++++++++++-
 python/pyspark/sql/tests/test_udtf.py              | 11 ++++++--
 4 files changed, 45 insertions(+), 24 deletions(-)

diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 2a21cdf0675..0db2833d2c1 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -170,6 +170,7 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
         ret.year_month_interval.start_field = data_type.startField
         ret.year_month_interval.end_field = data_type.endField
     elif isinstance(data_type, StructType):
+        struct = pb2.DataType.Struct()
         for field in data_type.fields:
             struct_field = pb2.DataType.StructField()
             struct_field.name = field.name
@@ -177,7 +178,8 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
             struct_field.nullable = field.nullable
             if field.metadata is not None and len(field.metadata) > 0:
                 struct_field.metadata = json.dumps(field.metadata)
-            ret.struct.fields.append(struct_field)
+            struct.fields.append(struct_field)
+        ret.struct.CopyFrom(struct)
     elif isinstance(data_type, MapType):
         ret.map.key_type.CopyFrom(pyspark_types_to_proto_types(data_type.keyType))
         ret.map.value_type.CopyFrom(pyspark_types_to_proto_types(data_type.valueType))
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 90a24197f64..15de00782c6 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -385,37 +385,28 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
         """
         import pyarrow as pa
 
-        # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
-        if len(df) == 0 and len(df.columns) == 0:
-            arrs_names = [
-                (pa.array([], type=field.type), field.name) for field in arrow_struct_type
-            ]
+        if len(df.columns) == 0:
+            return pa.array([{}] * len(df), arrow_struct_type)
         # Assign result columns by schema name if user labeled with strings
-        elif self._assign_cols_by_name and any(isinstance(name, str) for name in df.columns):
-            arrs_names = [
-                (
-                    self._create_array(df[field.name], field.type, arrow_cast=self._arrow_cast),
-                    field.name,
-                )
+        if self._assign_cols_by_name and any(isinstance(name, str) for name in df.columns):
+            struct_arrs = [
+                self._create_array(df[field.name], field.type, arrow_cast=self._arrow_cast)
                 for field in arrow_struct_type
             ]
         # Assign result columns by position
         else:
-            arrs_names = [
+            struct_arrs = [
                 # the selected series has name '1', so we rename it to field.name
                 # as the name is used by _create_array to provide a meaningful error message
-                (
-                    self._create_array(
-                        df[df.columns[i]].rename(field.name),
-                        field.type,
-                        arrow_cast=self._arrow_cast,
-                    ),
-                    field.name,
+                self._create_array(
+                    df[df.columns[i]].rename(field.name),
+                    field.type,
+                    arrow_cast=self._arrow_cast,
                 )
                 for i, field in enumerate(arrow_struct_type)
             ]
 
-        struct_arrs, struct_names = zip(*arrs_names)
+        struct_names = [field.name for field in arrow_struct_type]
         return pa.StructArray.from_arrays(struct_arrs, struct_names)
 
     def _create_batch(self, series):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index fef21224266..7a80547b3fc 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -60,7 +60,7 @@ from pyspark.testing.sqlutils import (
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 if have_pandas:
     import pandas as pd
@@ -528,6 +528,27 @@ class ScalarPandasUDFTestsMixin:
                 self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0])
                 self.assertListEqual([i, i + 1], f[1])
 
+    def test_vectorized_udf_struct_empty(self):
+        df = self.spark.range(3)
+        return_type = StructType()
+
+        def _scalar_f(id):
+            return pd.DataFrame(index=id)
+
+        scalar_f = pandas_udf(_scalar_f, returnType=return_type)
+
+        @pandas_udf(returnType=return_type, functionType=PandasUDFType.SCALAR_ITER)
+        def iter_f(it):
+            for id in it:
+                yield _scalar_f(id)
+
+        for f, udf_type in [(scalar_f, "SCALAR"), (iter_f, "SCALAR_ITER")]:
+            with self.subTest(udf_type=udf_type):
+                assertDataFrameEqual(
+                    df.withColumn("f", f(col("id"))),
+                    [Row(id=0, f=Row()), Row(id=1, f=Row()), Row(id=2, f=Row())],
+                )
+
     def test_vectorized_udf_nested_struct(self):
         with QuietTest(self.sc):
             self.check_vectorized_udf_nested_struct()
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index a6c7b97acc4..688034b9930 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -487,6 +487,14 @@ class BaseUDTFTestsMixin:
 
         self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x={1: "1"})])
 
+    def test_udtf_with_empty_output_types(self):
+        @udtf(returnType=StructType())
+        class TestUDTF:
+            def eval(self):
+                yield tuple()
+
+        assertDataFrameEqual(TestUDTF(), [Row()])
+
     @unittest.skipIf(not have_pandas, pandas_requirement_message)
     def test_udtf_with_pandas_input_type(self):
         import pandas as pd
@@ -932,8 +940,7 @@ class BaseUDTFTestsMixin:
                     StructType().add("col0", IntegerType()).add("col1", StringType()),
                     [Row(a=1, b="x")],
                 ),
-                # TODO(SPARK-44479): Support Python UDTFs with empty schema
-                # (func(), StructType(), [Row()]),
+                (func(), StructType(), [Row()]),
             ]
         ):
             with self.subTest(query_no=i):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to