This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new df858d3c3a3 [SPARK-42998][CONNECT][PYTHON] Fix DataFrame.collect with null struct
df858d3c3a3 is described below

commit df858d3c3a3d7652a92c4fe8ac058999f9fa17ca
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Sat Apr 1 09:34:55 2023 +0800

    [SPARK-42998][CONNECT][PYTHON] Fix DataFrame.collect with null struct
    
    ### What changes were proposed in this pull request?
    
    Fix `DataFrame.collect` with null struct.
    
    ### Why are the changes needed?
    
    There is a behavior difference when collecting a `null` struct:
    
    In Spark Connect:
    
    ```py
    >>> df = spark.sql("values (1, struct('a' as x)), (2, struct(null as x)), 
(null, null) as t(a, b)")
    >>> df.printSchema()
    root
     |-- a: integer (nullable = true)
     |-- b: struct (nullable = true)
     |    |-- x: string (nullable = true)
    >>> df.show()
    +----+------+
    |   a|     b|
    +----+------+
    |   1|   {a}|
    |   2|{null}|
    |null|  null|
    +----+------+
    
    >>> df.collect()
    [Row(a=1, b=Row(x='a')), Row(a=2, b=Row(x=None)), Row(a=None, b=<Row()>)]
    ```
    
    whereas in PySpark:
    
    ```py
    >>> df.collect()
    [Row(a=1, b=Row(x='a')), Row(a=2, b=Row(x=None)), Row(a=None, b=None)]
    ```
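    
    The root cause is in `ArrowTableToRowsConversion`: its struct converter returned an empty `Row()` for a null struct instead of `None` (see the diff below). As a minimal check of the fixed behavior, assuming a Spark Connect session `spark`:
    
    ```py
    >>> df = spark.sql("values (1, struct('a' as x)), (2, struct(null as x)), (null, null) as t(a, b)")
    >>> df.collect()[2].b is None
    True
    ```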
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it fixes the behavior difference described above.
    
    ### How was this patch tested?
    
    Added/modified the related tests.
    
    Closes #40627 from ueshin/issues/SPARK-42998/null_struct.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
    (cherry picked from commit 74cddcfda3ac4779de80696cdae2ba64d53fc635)
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/conversion.py           |  4 +-
 .../sql/tests/connect/test_connect_basic.py        | 99 +++++++++++++++-------
 2 files changed, 72 insertions(+), 31 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index ba488d4d04e..99e4a477d87 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -327,9 +327,9 @@ class ArrowTableToRowsConversion:
                 ArrowTableToRowsConversion._need_converter(f.dataType) for f in dataType.fields
             )
 
-            def convert_struct(value: Any) -> Row:
+            def convert_struct(value: Any) -> Any:
                 if value is None:
-                    return Row()
+                    return None
                 else:
                     assert isinstance(value, dict)
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 1e798994746..8e3a12dc678 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2902,38 +2902,79 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(cdf4.collect(), sdf4.collect())
 
     def test_array_has_nullable(self):
-        schema_array_false = StructType().add("arr", ArrayType(IntegerType(), False))
-        cdf1 = self.connect.createDataFrame([Row([1, 2]), Row([3])], schema=schema_array_false)
-        sdf1 = self.spark.createDataFrame([Row([1, 2]), Row([3])], schema=schema_array_false)
-        self.assertEqual(cdf1.schema, sdf1.schema)
-        self.assertEqual(cdf1.collect(), sdf1.collect())
-
-        schema_array_true = StructType().add("arr", ArrayType(IntegerType(), True))
-        cdf2 = self.connect.createDataFrame([Row([1, None]), Row([3])], schema=schema_array_true)
-        sdf2 = self.spark.createDataFrame([Row([1, None]), Row([3])], schema=schema_array_true)
-        self.assertEqual(cdf2.schema, sdf2.schema)
-        self.assertEqual(cdf2.collect(), sdf2.collect())
+        for schema, data in [
+            (
+                StructType().add("arr", ArrayType(IntegerType(), False), True),
+                [Row([1, 2]), Row([3]), Row(None)],
+            ),
+            (
+                StructType().add("arr", ArrayType(IntegerType(), True), True),
+                [Row([1, None]), Row([3]), Row(None)],
+            ),
+            (
+                StructType().add("arr", ArrayType(IntegerType(), False), 
False),
+                [Row([1, 2]), Row([3])],
+            ),
+            (
+                StructType().add("arr", ArrayType(IntegerType(), True), False),
+                [Row([1, None]), Row([3])],
+            ),
+        ]:
+            with self.subTest(schema=schema):
+                cdf = self.connect.createDataFrame(data, schema=schema)
+                sdf = self.spark.createDataFrame(data, schema=schema)
+                self.assertEqual(cdf.schema, sdf.schema)
+                self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_map_has_nullable(self):
-        schema_map_false = StructType().add("map", MapType(StringType(), IntegerType(), False))
-        cdf1 = self.connect.createDataFrame(
-            [Row({"a": 1, "b": 2}), Row({"a": 3})], schema=schema_map_false
-        )
-        sdf1 = self.spark.createDataFrame(
-            [Row({"a": 1, "b": 2}), Row({"a": 3})], schema=schema_map_false
-        )
-        self.assertEqual(cdf1.schema, sdf1.schema)
-        self.assertEqual(cdf1.collect(), sdf1.collect())
+        for schema, data in [
+            (
+                StructType().add("map", MapType(StringType(), IntegerType(), 
False), True),
+                [Row({"a": 1, "b": 2}), Row({"a": 3}), Row(None)],
+            ),
+            (
+                StructType().add("map", MapType(StringType(), IntegerType(), 
True), True),
+                [Row({"a": 1, "b": None}), Row({"a": 3}), Row(None)],
+            ),
+            (
+                StructType().add("map", MapType(StringType(), IntegerType(), 
False), False),
+                [Row({"a": 1, "b": 2}), Row({"a": 3})],
+            ),
+            (
+                StructType().add("map", MapType(StringType(), IntegerType(), 
True), False),
+                [Row({"a": 1, "b": None}), Row({"a": 3})],
+            ),
+        ]:
+            with self.subTest(schema=schema):
+                cdf = self.connect.createDataFrame(data, schema=schema)
+                sdf = self.spark.createDataFrame(data, schema=schema)
+                self.assertEqual(cdf.schema, sdf.schema)
+                self.assertEqual(cdf.collect(), sdf.collect())
 
-        schema_map_true = StructType().add("map", MapType(StringType(), IntegerType(), True))
-        cdf2 = self.connect.createDataFrame(
-            [Row({"a": 1, "b": None}), Row({"a": 3})], schema=schema_map_true
-        )
-        sdf2 = self.spark.createDataFrame(
-            [Row({"a": 1, "b": None}), Row({"a": 3})], schema=schema_map_true
-        )
-        self.assertEqual(cdf2.schema, sdf2.schema)
-        self.assertEqual(cdf2.collect(), sdf2.collect())
+    def test_struct_has_nullable(self):
+        for schema, data in [
+            (
+                StructType().add("struct", StructType().add("i", 
IntegerType(), False), True),
+                [Row(Row(1)), Row(Row(2)), Row(None)],
+            ),
+            (
+                StructType().add("struct", StructType().add("i", 
IntegerType(), True), True),
+                [Row(Row(1)), Row(Row(2)), Row(Row(None)), Row(None)],
+            ),
+            (
+                StructType().add("struct", StructType().add("i", 
IntegerType(), False), False),
+                [Row(Row(1)), Row(Row(2))],
+            ),
+            (
+                StructType().add("struct", StructType().add("i", 
IntegerType(), True), False),
+                [Row(Row(1)), Row(Row(2)), Row(Row(None))],
+            ),
+        ]:
+            with self.subTest(schema=schema):
+                cdf = self.connect.createDataFrame(data, schema=schema)
+                sdf = self.spark.createDataFrame(data, schema=schema)
+                self.assertEqual(cdf.schema, sdf.schema)
+                self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_large_client_data(self):
         # SPARK-42816 support more than 4MB message size.

