This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new df858d3c3a3  [SPARK-42998][CONNECT][PYTHON] Fix DataFrame.collect with null struct
df858d3c3a3 is described below

commit df858d3c3a3d7652a92c4fe8ac058999f9fa17ca
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Sat Apr 1 09:34:55 2023 +0800

    [SPARK-42998][CONNECT][PYTHON] Fix DataFrame.collect with null struct

    ### What changes were proposed in this pull request?

    Fix `DataFrame.collect` with a null struct.

    ### Why are the changes needed?

    There is a behavior difference when collecting a `null` struct.

    In Spark Connect:

    ```py
    >>> df = spark.sql("values (1, struct('a' as x)), (2, struct(null as x)), (null, null) as t(a, b)")
    >>> df.printSchema()
    root
     |-- a: integer (nullable = true)
     |-- b: struct (nullable = true)
     |    |-- x: string (nullable = true)

    >>> df.show()
    +----+------+
    |   a|     b|
    +----+------+
    |   1|   {a}|
    |   2|{null}|
    |null|  null|
    +----+------+

    >>> df.collect()
    [Row(a=1, b=Row(x='a')), Row(a=2, b=Row(x=None)), Row(a=None, b=<Row()>)]
    ```

    whereas in PySpark:

    ```py
    >>> df.collect()
    [Row(a=1, b=Row(x='a')), Row(a=2, b=Row(x=None)), Row(a=None, b=None)]
    ```

    ### Does this PR introduce _any_ user-facing change?

    Yes, the behavior fix described above: collecting a null struct now returns `None` instead of an empty `Row`.

    ### How was this patch tested?

    Added/modified the related tests.

    Closes #40627 from ueshin/issues/SPARK-42998/null_struct.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
    (cherry picked from commit 74cddcfda3ac4779de80696cdae2ba64d53fc635)
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/conversion.py           |  4 +-
 .../sql/tests/connect/test_connect_basic.py        | 99 +++++++++++++++-------
 2 files changed, 72 insertions(+), 31 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index ba488d4d04e..99e4a477d87 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -327,9 +327,9 @@ class ArrowTableToRowsConversion:
                 ArrowTableToRowsConversion._need_converter(f.dataType) for f in dataType.fields
             )

-            def convert_struct(value: Any) -> Row:
+            def convert_struct(value: Any) -> Any:
                 if value is None:
-                    return Row()
+                    return None
                 else:
                     assert isinstance(value, dict)
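The hunk above is the entire behavior change. For illustration only, a minimal standalone
sketch of the same pattern; the `convert_struct` below is a simplified stand-in, not the
real converter (the actual one also applies per-field converters to nested values):

```py
from pyspark.sql import Row

def convert_struct(value):
    # Simplified stand-in for the fixed conversion logic: a null struct
    # becomes None (previously it became an empty Row(), which collect()
    # rendered as <Row()>); a present struct keeps per-field nulls.
    if value is None:
        return None
    assert isinstance(value, dict)
    return Row(**value)

assert convert_struct(None) is None              # null struct -> None, as in PySpark
assert convert_struct({"x": None}) == Row(x=None)  # struct with a null field survives
assert convert_struct({"x": "a"}) == Row(x="a")
```

This keeps a missing struct distinguishable from a struct whose fields are all null,
matching the PySpark `collect()` output shown in the description.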
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 1e798994746..8e3a12dc678 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2902,38 +2902,79 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(cdf4.collect(), sdf4.collect())

     def test_array_has_nullable(self):
-        schema_array_false = StructType().add("arr", ArrayType(IntegerType(), False))
-        cdf1 = self.connect.createDataFrame([Row([1, 2]), Row([3])], schema=schema_array_false)
-        sdf1 = self.spark.createDataFrame([Row([1, 2]), Row([3])], schema=schema_array_false)
-        self.assertEqual(cdf1.schema, sdf1.schema)
-        self.assertEqual(cdf1.collect(), sdf1.collect())
-
-        schema_array_true = StructType().add("arr", ArrayType(IntegerType(), True))
-        cdf2 = self.connect.createDataFrame([Row([1, None]), Row([3])], schema=schema_array_true)
-        sdf2 = self.spark.createDataFrame([Row([1, None]), Row([3])], schema=schema_array_true)
-        self.assertEqual(cdf2.schema, sdf2.schema)
-        self.assertEqual(cdf2.collect(), sdf2.collect())
+        for schema, data in [
+            (
+                StructType().add("arr", ArrayType(IntegerType(), False), True),
+                [Row([1, 2]), Row([3]), Row(None)],
+            ),
+            (
+                StructType().add("arr", ArrayType(IntegerType(), True), True),
+                [Row([1, None]), Row([3]), Row(None)],
+            ),
+            (
+                StructType().add("arr", ArrayType(IntegerType(), False), False),
+                [Row([1, 2]), Row([3])],
+            ),
+            (
+                StructType().add("arr", ArrayType(IntegerType(), True), False),
+                [Row([1, None]), Row([3])],
+            ),
+        ]:
+            with self.subTest(schema=schema):
+                cdf = self.connect.createDataFrame(data, schema=schema)
+                sdf = self.spark.createDataFrame(data, schema=schema)
+                self.assertEqual(cdf.schema, sdf.schema)
+                self.assertEqual(cdf.collect(), sdf.collect())

     def test_map_has_nullable(self):
-        schema_map_false = StructType().add("map", MapType(StringType(), IntegerType(), False))
-        cdf1 = self.connect.createDataFrame(
-            [Row({"a": 1, "b": 2}), Row({"a": 3})], schema=schema_map_false
-        )
-        sdf1 = self.spark.createDataFrame(
-            [Row({"a": 1, "b": 2}), Row({"a": 3})], schema=schema_map_false
-        )
-        self.assertEqual(cdf1.schema, sdf1.schema)
-        self.assertEqual(cdf1.collect(), sdf1.collect())
+        for schema, data in [
+            (
+                StructType().add("map", MapType(StringType(), IntegerType(), False), True),
+                [Row({"a": 1, "b": 2}), Row({"a": 3}), Row(None)],
+            ),
+            (
+                StructType().add("map", MapType(StringType(), IntegerType(), True), True),
+                [Row({"a": 1, "b": None}), Row({"a": 3}), Row(None)],
+            ),
+            (
+                StructType().add("map", MapType(StringType(), IntegerType(), False), False),
+                [Row({"a": 1, "b": 2}), Row({"a": 3})],
+            ),
+            (
+                StructType().add("map", MapType(StringType(), IntegerType(), True), False),
+                [Row({"a": 1, "b": None}), Row({"a": 3})],
+            ),
+        ]:
+            with self.subTest(schema=schema):
+                cdf = self.connect.createDataFrame(data, schema=schema)
+                sdf = self.spark.createDataFrame(data, schema=schema)
+                self.assertEqual(cdf.schema, sdf.schema)
+                self.assertEqual(cdf.collect(), sdf.collect())

-        schema_map_true = StructType().add("map", MapType(StringType(), IntegerType(), True))
-        cdf2 = self.connect.createDataFrame(
-            [Row({"a": 1, "b": None}), Row({"a": 3})], schema=schema_map_true
-        )
-        sdf2 = self.spark.createDataFrame(
-            [Row({"a": 1, "b": None}), Row({"a": 3})], schema=schema_map_true
-        )
-        self.assertEqual(cdf2.schema, sdf2.schema)
-        self.assertEqual(cdf2.collect(), sdf2.collect())
+    def test_struct_has_nullable(self):
+        for schema, data in [
+            (
+                StructType().add("struct", StructType().add("i", IntegerType(), False), True),
+                [Row(Row(1)), Row(Row(2)), Row(None)],
+            ),
+            (
+                StructType().add("struct", StructType().add("i", IntegerType(), True), True),
+                [Row(Row(1)), Row(Row(2)), Row(Row(None)), Row(None)],
+            ),
+            (
+                StructType().add("struct", StructType().add("i", IntegerType(), False), False),
+                [Row(Row(1)), Row(Row(2))],
+            ),
+            (
+                StructType().add("struct", StructType().add("i", IntegerType(), True), False),
+                [Row(Row(1)), Row(Row(2)), Row(Row(None))],
+            ),
+        ]:
+            with self.subTest(schema=schema):
+                cdf = self.connect.createDataFrame(data, schema=schema)
+                sdf = self.spark.createDataFrame(data, schema=schema)
+                self.assertEqual(cdf.schema, sdf.schema)
+                self.assertEqual(cdf.collect(), sdf.collect())

     def test_large_client_data(self):
         # SPARK-42816 support more than 4MB message size.

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org