This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 50d9f94225b [SPARK-43055][CONNECT][PYTHON][FOLLOWUP] Fix deduplicate field names and refactor
50d9f94225b is described below

commit 50d9f94225ba0f127ceaebfea465ac450b017f86
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Fri Apr 21 14:57:58 2023 +0900

    [SPARK-43055][CONNECT][PYTHON][FOLLOWUP] Fix deduplicate field names and refactor
    
    ### What changes were proposed in this pull request?
    
    Fixes deduplication of field names, and refactors `ArrowTableToRowsConversion.convert` and `LocalDataToArrowConversion.convert` to share the same renaming rule.
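
    For reference, the shared renaming rule (the `_dedup_names` helper added in this patch) keeps unique names as-is and suffixes every occurrence of a duplicated name with a positional counter. A minimal sketch of the intended behavior, assuming the helper is imported from `pyspark.sql.connect.conversion` where this patch defines it:

    ```py
    >>> from pyspark.sql.connect.conversion import _dedup_names
    >>> _dedup_names(["a", "b", "c"])  # all unique: returned unchanged
    ['a', 'b', 'c']
    >>> _dedup_names(["x", "a", "x", "y", "y", "x"])  # duplicates get _0, _1, ...
    ['x_0', 'a', 'x_1', 'y_0', 'y_1', 'x_2']
    ```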
    
    ### Why are the changes needed?
    
    If a field name is duplicated at non-adjacent positions within a struct, deduplication fails and a wrong result is returned.
    
    ```py
    >>> from pyspark.sql.types import *
    >>> data = [
    ...     Row(Row("a", 1), Row(2, 3, "b", 4, "c", "d")),
    ...     Row(Row("w", 6), Row(7, 8, "x", 9, "y", "z")),
    ... ]
    >>> schema = (
    ...     StructType()
    ...     .add("struct", StructType().add("x", StringType()).add("x", 
IntegerType()))
    ...     .add(
    ...         "struct",
    ...         StructType()
    ...         .add("a", IntegerType())
    ...         .add("x", IntegerType())
    ...         .add("x", StringType())
    ...         .add("y", IntegerType())
    ...         .add("y", StringType())
    ...         .add("x", StringType()),
    ...     )
    ... )
    >>> df = spark.createDataFrame(data, schema=schema)
    >>>
    >>> df.collect()
    [Row(struct=Row(x='a', x=1), struct=Row(a=2, x=None, x=None, y=4, y='c', x=None)), Row(struct=Row(x='w', x=6), struct=Row(a=7, x=None, x=None, y=9, y='y', x=None))]
    ```
    
    It should be:
    
    ```py
    >>> df.collect()
    [Row(struct=Row(x='a', x=1), struct=Row(a=2, x=3, x='b', y=4, y='c', x='d')), Row(struct=Row(x='w', x=6), struct=Row(a=7, x=8, x='x', y=9, y='y', x='z'))]
    ```
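
    The wrong result above comes from the previous renaming rule running `itertools.groupby` over the field names without sorting them first; `groupby` only groups adjacent equal elements, so duplicates separated by other names were treated as unique and silently collided. A small illustration of the difference (plain Python, not Spark code):

    ```py
    >>> import itertools
    >>> names = ["x", "y", "x"]
    >>> [(k, len(list(g))) for k, g in itertools.groupby(names)]  # adjacent-only grouping misses split duplicates
    [('x', 1), ('y', 1), ('x', 1)]
    >>> [(k, len(list(g))) for k, g in itertools.groupby(sorted(names))]  # sorting first, as this patch does
    [('x', 2), ('y', 1)]
    ```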
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Updated the related test.
    
    Closes #40888 from ueshin/issues/SPARK-43055/fix_dedup.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/conversion.py   | 102 +++++++++++++++++------------
 python/pyspark/sql/tests/test_dataframe.py |   8 ++-
 2 files changed, 65 insertions(+), 45 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index a6fe0c00e09..16679e80205 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -40,7 +40,6 @@ from pyspark.sql.types import (
     DecimalType,
     StringType,
     UserDefinedType,
-    cast,
 )
 
 from pyspark.storagelevel import StorageLevel
@@ -50,7 +49,6 @@ import pyspark.sql.connect.proto as pb2
 from typing import (
     Any,
     Callable,
-    Dict,
     Sequence,
     List,
 )
@@ -104,6 +102,7 @@ class LocalDataToArrowConversion:
         elif isinstance(dataType, StructType):
 
             field_names = dataType.fieldNames()
+            dedup_field_names = _dedup_names(dataType.names)
 
             field_convs = [
                 LocalDataToArrowConversion._create_converter(field.dataType)
@@ -123,7 +122,7 @@ class LocalDataToArrowConversion:
                         value = value.__dict__
                     if isinstance(value, dict):
                         for i, field in enumerate(field_names):
-                            _dict[f"col_{i}"] = field_convs[i](value.get(field))
+                            _dict[dedup_field_names[i]] = field_convs[i](value.get(field))
                     else:
                         if len(value) != len(field_names):
                             raise ValueError(
@@ -131,7 +130,7 @@ class LocalDataToArrowConversion:
                                 f"new values have {len(value)} elements"
                             )
                         for i in range(len(field_names)):
-                            _dict[f"col_{i}"] = field_convs[i](value[i])
+                            _dict[dedup_field_names[i]] = field_convs[i](value[i])
 
                     return _dict
 
@@ -290,26 +289,16 @@ class LocalDataToArrowConversion:
                 for i in range(len(column_names)):
                     pylist[i].append(column_convs[i](item[i]))
 
-        def normalize(dt: DataType) -> DataType:
-            if isinstance(dt, StructType):
-                return StructType(
-                    [
-                        StructField(f"col_{i}", normalize(field.dataType), nullable=field.nullable)
-                        for i, field in enumerate(dt.fields)
-                    ]
-                )
-            elif isinstance(dt, ArrayType):
-                return ArrayType(normalize(dt.elementType), containsNull=dt.containsNull)
-            elif isinstance(dt, MapType):
-                return MapType(
-                    normalize(dt.keyType),
-                    normalize(dt.valueType),
-                    valueContainsNull=dt.valueContainsNull,
-                )
-            else:
-                return dt
-
-        pa_schema = to_arrow_schema(cast(StructType, normalize(schema)))
+        pa_schema = to_arrow_schema(
+            StructType(
+                [
+                    StructField(
+                        field.name, _deduplicate_field_names(field.dataType), field.nullable
+                    )
+                    for field in schema.fields
+                ]
+            )
+        )
 
         return pa.Table.from_arrays(pylist, schema=pa_schema)
 
@@ -355,25 +344,7 @@ class ArrowTableToRowsConversion:
         elif isinstance(dataType, StructType):
 
             field_names = dataType.names
-
-            if len(set(field_names)) == len(field_names):
-                dedup_field_names = field_names
-            else:
-                gen_new_name: Dict[str, Callable[[], str]] = {}
-                for name, group in itertools.groupby(dataType.names):
-                    if len(list(group)) > 1:
-
-                        def _gen(_name: str) -> Callable[[], str]:
-                            _i = itertools.count()
-                            return lambda: f"{_name}_{next(_i)}"
-
-                    else:
-
-                        def _gen(_name: str) -> Callable[[], str]:
-                            return lambda: _name
-
-                    gen_new_name[name] = _gen(name)
-                dedup_field_names = [gen_new_name[name]() for name in dataType.names]
+            dedup_field_names = _dedup_names(field_names)
 
             field_convs = [
                 ArrowTableToRowsConversion._create_converter(f.dataType) for f in dataType.fields
@@ -510,3 +481,48 @@ def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel:
         deserialized=storage_level.deserialized,
         replication=storage_level.replication,
     )
+
+
+def _deduplicate_field_names(dt: DataType) -> DataType:
+    if isinstance(dt, StructType):
+        dedup_field_names = _dedup_names(dt.names)
+
+        return StructType(
+            [
+                StructField(
+                    dedup_field_names[i],
+                    _deduplicate_field_names(field.dataType),
+                    nullable=field.nullable,
+                )
+                for i, field in enumerate(dt.fields)
+            ]
+        )
+    elif isinstance(dt, ArrayType):
+        return ArrayType(_deduplicate_field_names(dt.elementType), containsNull=dt.containsNull)
+    elif isinstance(dt, MapType):
+        return MapType(
+            _deduplicate_field_names(dt.keyType),
+            _deduplicate_field_names(dt.valueType),
+            valueContainsNull=dt.valueContainsNull,
+        )
+    else:
+        return dt
+
+
+def _dedup_names(names: List[str]) -> List[str]:
+    if len(set(names)) == len(names):
+        return names
+    else:
+
+        def _gen_dedup(_name: str) -> Callable[[], str]:
+            _i = itertools.count()
+            return lambda: f"{_name}_{next(_i)}"
+
+        def _gen_identity(_name: str) -> Callable[[], str]:
+            return lambda: _name
+
+        gen_new_name = {
+            name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name)
+            for name, group in itertools.groupby(sorted(names))
+        }
+        return [gen_new_name[name]() for name in names]
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 5716acbaabc..164b6a22a69 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1710,7 +1710,10 @@ class DataFrameTestsMixin:
         )
 
     def test_duplicate_field_names(self):
-        data = [Row(Row("a", 1), Row(2, 3, "b", 4, "c")), Row(Row("x", 6), Row(7, 8, "y", 9, "z"))]
+        data = [
+            Row(Row("a", 1), Row(2, 3, "b", 4, "c", "d")),
+            Row(Row("w", 6), Row(7, 8, "x", 9, "y", "z")),
+        ]
         schema = (
             StructType()
             .add("struct", StructType().add("x", StringType()).add("x", 
IntegerType()))
@@ -1721,7 +1724,8 @@ class DataFrameTestsMixin:
                 .add("x", IntegerType())
                 .add("x", StringType())
                 .add("y", IntegerType())
-                .add("y", StringType()),
+                .add("y", StringType())
+                .add("x", StringType()),
             )
         )
         df = self.spark.createDataFrame(data, schema=schema)

