This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 50d9f94225b [SPARK-43055][CONNECT][PYTHON][FOLLOWUP] Fix deduplicate field names and refactor
50d9f94225b is described below

commit 50d9f94225ba0f127ceaebfea465ac450b017f86
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Fri Apr 21 14:57:58 2023 +0900

[SPARK-43055][CONNECT][PYTHON][FOLLOWUP] Fix deduplicate field names and refactor

### What changes were proposed in this pull request?

Fixes the deduplication of field names, and refactors both `ArrowTableToRowsConversion.convert` and `LocalDataToArrowConversion.convert` to use the same renaming rule.

### Why are the changes needed?

If a duplicated field name appears in a non-adjacent position, deduplication fails and a wrong result is returned:

```py
>>> from pyspark.sql.types import *
>>> data = [
...     Row(Row("a", 1), Row(2, 3, "b", 4, "c", "d")),
...     Row(Row("w", 6), Row(7, 8, "x", 9, "y", "z")),
... ]
>>> schema = (
...     StructType()
...     .add("struct", StructType().add("x", StringType()).add("x", IntegerType()))
...     .add(
...         "struct",
...         StructType()
...         .add("a", IntegerType())
...         .add("x", IntegerType())
...         .add("x", StringType())
...         .add("y", IntegerType())
...         .add("y", StringType())
...         .add("x", StringType()),
...     )
... )
>>> df = spark.createDataFrame(data, schema=schema)
>>>
>>> df.collect()
[Row(struct=Row(x='a', x=1), struct=Row(a=2, x=None, x=None, y=4, y='c', x=None)), Row(struct=Row(x='w', x=6), struct=Row(a=7, x=None, x=None, y=9, y='y', x=None))]
```

It should be:

```py
>>> df.collect()
[Row(struct=Row(x='a', x=1), struct=Row(a=2, x=3, x='b', y=4, y='c', x='d')), Row(struct=Row(x='w', x=6), struct=Row(a=7, x=8, x='x', y=9, y='y', x='z'))]
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated the related test.

Closes #40888 from ueshin/issues/SPARK-43055/fix_dedup.
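For context, the failure mode lies in how the old code grouped names: `itertools.groupby` only merges *adjacent* equal keys, so a name recurring in a non-adjacent position forms a second, singleton group whose identity-renaming closure overwrites the deduplicating one; the patch fixes this by grouping over `sorted(names)`. A minimal standalone sketch of the renaming rule (plain Python mirroring the `_dedup_names` helper in the diff below; the `sort` flag is not part of the patch and is added here only to contrast the old and new behavior):

```py
import itertools
from typing import Callable, List


def dedup_names(names: List[str], sort: bool) -> List[str]:
    """Rename duplicated names to f"{name}_{i}"; keep unique names as-is."""
    if len(set(names)) == len(names):
        return names

    def _gen_dedup(_name: str) -> Callable[[], str]:
        _i = itertools.count()
        return lambda: f"{_name}_{next(_i)}"

    def _gen_identity(_name: str) -> Callable[[], str]:
        return lambda: _name

    # groupby only merges adjacent equal keys, so sorting first is essential;
    # without it, a later singleton group overwrites the deduplicating entry.
    gen_new_name = {
        name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name)
        for name, group in itertools.groupby(sorted(names) if sort else names)
    }
    return [gen_new_name[name]() for name in names]


names = ["a", "x", "x", "y", "y", "x"]  # "x" recurs in a non-adjacent position
print(dedup_names(names, sort=False))   # ['a', 'x', 'x', 'y_0', 'y_1', 'x'] (old: still duplicated)
print(dedup_names(names, sort=True))    # ['a', 'x_0', 'x_1', 'y_0', 'y_1', 'x_2'] (new: unique)
```

The duplicated keys produced by the old path then collide when the struct converter builds its Python dict of Arrow values, which is why the affected nested fields surfaced as `None` in the example above.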
Authored-by: Takuya UESHIN <ues...@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/conversion.py   | 102 +++++++++++++++++------------
 python/pyspark/sql/tests/test_dataframe.py |   8 ++-
 2 files changed, 65 insertions(+), 45 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index a6fe0c00e09..16679e80205 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -40,7 +40,6 @@ from pyspark.sql.types import (
     DecimalType,
     StringType,
     UserDefinedType,
-    cast,
 )
 
 from pyspark.storagelevel import StorageLevel
@@ -50,7 +49,6 @@ import pyspark.sql.connect.proto as pb2
 from typing import (
     Any,
     Callable,
-    Dict,
     Sequence,
     List,
 )
@@ -104,6 +102,7 @@ class LocalDataToArrowConversion:
 
         elif isinstance(dataType, StructType):
             field_names = dataType.fieldNames()
+            dedup_field_names = _dedup_names(dataType.names)
 
             field_convs = [
                 LocalDataToArrowConversion._create_converter(field.dataType)
@@ -123,7 +122,7 @@ class LocalDataToArrowConversion:
                         value = value.__dict__
                     if isinstance(value, dict):
                         for i, field in enumerate(field_names):
-                            _dict[f"col_{i}"] = field_convs[i](value.get(field))
+                            _dict[dedup_field_names[i]] = field_convs[i](value.get(field))
                     else:
                         if len(value) != len(field_names):
                             raise ValueError(
@@ -131,7 +130,7 @@
                                 f"new values have {len(value)} elements"
                             )
                         for i in range(len(field_names)):
-                            _dict[f"col_{i}"] = field_convs[i](value[i])
+                            _dict[dedup_field_names[i]] = field_convs[i](value[i])
 
                 return _dict
 
@@ -290,26 +289,16 @@ class LocalDataToArrowConversion:
             for i in range(len(column_names)):
                 pylist[i].append(column_convs[i](item[i]))
 
-        def normalize(dt: DataType) -> DataType:
-            if isinstance(dt, StructType):
-                return StructType(
-                    [
-                        StructField(f"col_{i}", normalize(field.dataType), nullable=field.nullable)
-                        for i, field in enumerate(dt.fields)
-                    ]
-                )
-            elif isinstance(dt, ArrayType):
-                return ArrayType(normalize(dt.elementType), containsNull=dt.containsNull)
-            elif isinstance(dt, MapType):
-                return MapType(
-                    normalize(dt.keyType),
-                    normalize(dt.valueType),
-                    valueContainsNull=dt.valueContainsNull,
-                )
-            else:
-                return dt
-
-        pa_schema = to_arrow_schema(cast(StructType, normalize(schema)))
+        pa_schema = to_arrow_schema(
+            StructType(
+                [
+                    StructField(
+                        field.name, _deduplicate_field_names(field.dataType), field.nullable
+                    )
+                    for field in schema.fields
+                ]
+            )
+        )
 
         return pa.Table.from_arrays(pylist, schema=pa_schema)
 
@@ -355,25 +344,7 @@ class ArrowTableToRowsConversion:
 
         elif isinstance(dataType, StructType):
             field_names = dataType.names
-
-            if len(set(field_names)) == len(field_names):
-                dedup_field_names = field_names
-            else:
-                gen_new_name: Dict[str, Callable[[], str]] = {}
-                for name, group in itertools.groupby(dataType.names):
-                    if len(list(group)) > 1:
-
-                        def _gen(_name: str) -> Callable[[], str]:
-                            _i = itertools.count()
-                            return lambda: f"{_name}_{next(_i)}"
-
-                    else:
-
-                        def _gen(_name: str) -> Callable[[], str]:
-                            return lambda: _name
-
-                    gen_new_name[name] = _gen(name)
-                dedup_field_names = [gen_new_name[name]() for name in dataType.names]
+            dedup_field_names = _dedup_names(field_names)
 
             field_convs = [
                 ArrowTableToRowsConversion._create_converter(f.dataType) for f in dataType.fields
@@ -510,3 +481,48 @@ def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel:
         deserialized=storage_level.deserialized,
         replication=storage_level.replication,
     )
+
+
+def _deduplicate_field_names(dt: DataType) -> DataType:
+    if isinstance(dt, StructType):
+        dedup_field_names = _dedup_names(dt.names)
+
+        return StructType(
+            [
+                StructField(
+                    dedup_field_names[i],
+                    _deduplicate_field_names(field.dataType),
+                    nullable=field.nullable,
+                )
+                for i, field in enumerate(dt.fields)
+            ]
+        )
+    elif isinstance(dt, ArrayType):
+        return ArrayType(_deduplicate_field_names(dt.elementType), containsNull=dt.containsNull)
+    elif isinstance(dt, MapType):
+        return MapType(
+            _deduplicate_field_names(dt.keyType),
+            _deduplicate_field_names(dt.valueType),
+            valueContainsNull=dt.valueContainsNull,
+        )
+    else:
+        return dt
+
+
+def _dedup_names(names: List[str]) -> List[str]:
+    if len(set(names)) == len(names):
+        return names
+    else:
+
+        def _gen_dedup(_name: str) -> Callable[[], str]:
+            _i = itertools.count()
+            return lambda: f"{_name}_{next(_i)}"
+
+        def _gen_identity(_name: str) -> Callable[[], str]:
+            return lambda: _name
+
+        gen_new_name = {
+            name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name)
+            for name, group in itertools.groupby(sorted(names))
+        }
+        return [gen_new_name[name]() for name in names]
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 5716acbaabc..164b6a22a69 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1710,7 +1710,10 @@ class DataFrameTestsMixin:
         )
 
     def test_duplicate_field_names(self):
-        data = [Row(Row("a", 1), Row(2, 3, "b", 4, "c")), Row(Row("x", 6), Row(7, 8, "y", 9, "z"))]
+        data = [
+            Row(Row("a", 1), Row(2, 3, "b", 4, "c", "d")),
+            Row(Row("w", 6), Row(7, 8, "x", 9, "y", "z")),
+        ]
         schema = (
             StructType()
             .add("struct", StructType().add("x", StringType()).add("x", IntegerType()))
@@ -1721,7 +1724,8 @@ class DataFrameTestsMixin:
                 .add("x", IntegerType())
                 .add("x", StringType())
                 .add("y", IntegerType())
-                .add("y", StringType()),
+                .add("y", StringType())
+                .add("x", StringType()),
             )
         )
         df = self.spark.createDataFrame(data, schema=schema)
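With both converters now deriving names from the same helper, a quick way to see the renaming rule applied recursively is to feed it the inner struct from the updated test (a sketch assuming a PySpark build that contains this commit; `_deduplicate_field_names` is a private helper, imported here purely for illustration):

```py
# Illustrative only: _deduplicate_field_names is a private helper added by
# this commit, not public API.
from pyspark.sql.connect.conversion import _deduplicate_field_names
from pyspark.sql.types import IntegerType, StringType, StructType

# The inner struct from test_duplicate_field_names: "x" and "y" recur.
inner = (
    StructType()
    .add("a", IntegerType())
    .add("x", IntegerType())
    .add("x", StringType())
    .add("y", IntegerType())
    .add("y", StringType())
    .add("x", StringType())
)

# Duplicated names are renamed by occurrence order; unique names are kept.
print(_deduplicate_field_names(inner).fieldNames())
# ['a', 'x_0', 'x_1', 'y_0', 'y_1', 'x_2']
```

The renaming only affects the intermediate Arrow schema: `ArrowTableToRowsConversion.convert` computes the same deduplicated names on the read side to look values up, so the `Row`s returned to the user keep the original duplicated field names, as the `df.collect()` output above shows.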