This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e55875b0bbe0 [SPARK-48372][SPARK-45716][PYTHON] Implement `StructType.treeString`
e55875b0bbe0 is described below

commit e55875b0bbe08c435ffcb0ea034ceb95938d8729
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 22 15:31:27 2024 +0800

    [SPARK-48372][SPARK-45716][PYTHON] Implement `StructType.treeString`

    ### What changes were proposed in this pull request?
    Implement `StructType.treeString`

    ### Why are the changes needed?
    Feature parity: this method was previously available only in Scala.

    ### Does this PR introduce _any_ user-facing change?
    yes

    ```
    In [2]: schema1 = DataType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")

    In [3]: print(schema1.treeString())
    root
     |-- c1: integer (nullable = true)
     |-- c2: struct (nullable = true)
     |    |-- c3: integer (nullable = true)
     |    |-- c4: struct (nullable = true)
     |    |    |-- c5: integer (nullable = true)
     |    |    |-- c6: integer (nullable = true)
    ```

    ### How was this patch tested?
    added tests

    ### Was this patch authored or co-authored using generative AI tooling?
    no

    Closes #46685 from zhengruifeng/py_tree_string.
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/tests/test_types.py | 241 +++++++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 87 +++++++++++- python/pyspark/sql/utils.py | 54 +++++++- 3 files changed, 380 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 4d6fc499b70b..ec07406b1191 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -1170,6 +1170,247 @@ class TypesTestsMixin: ) self.assertEqual(VariantType(), _parse_datatype_string("variant")) + def test_tree_string(self): + schema1 = DataType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>") + + self.assertEqual( + schema1.treeString().split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: struct (nullable = true)", + " | |-- c3: integer (nullable = true)", + " | |-- c4: struct (nullable = true)", + " | | |-- c5: integer (nullable = true)", + " | | |-- c6: integer (nullable = true)", + "", + ], + ) + self.assertEqual( + schema1.treeString(-1).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: struct (nullable = true)", + " | |-- c3: integer (nullable = true)", + " | |-- c4: struct (nullable = true)", + " | | |-- c5: integer (nullable = true)", + " | | |-- c6: integer (nullable = true)", + "", + ], + ) + self.assertEqual( + schema1.treeString(0).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: struct (nullable = true)", + " | |-- c3: integer (nullable = true)", + " | |-- c4: struct (nullable = true)", + " | | |-- c5: integer (nullable = true)", + " | | |-- c6: integer (nullable = true)", + "", + ], + ) + self.assertEqual( + schema1.treeString(1).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: struct (nullable = true)", + "", + ], + ) + self.assertEqual( + 
schema1.treeString(2).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: struct (nullable = true)", + " | |-- c3: integer (nullable = true)", + " | |-- c4: struct (nullable = true)", + "", + ], + ) + self.assertEqual( + schema1.treeString(3).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: struct (nullable = true)", + " | |-- c3: integer (nullable = true)", + " | |-- c4: struct (nullable = true)", + " | | |-- c5: integer (nullable = true)", + " | | |-- c6: integer (nullable = true)", + "", + ], + ) + self.assertEqual( + schema1.treeString(4).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: struct (nullable = true)", + " | |-- c3: integer (nullable = true)", + " | |-- c4: struct (nullable = true)", + " | | |-- c5: integer (nullable = true)", + " | | |-- c6: integer (nullable = true)", + "", + ], + ) + + schema2 = DataType.fromDDL( + "c1 INT, c2 ARRAY<STRUCT<c3: INT>>, c4 STRUCT<c5: INT, c6: ARRAY<ARRAY<INT>>>" + ) + self.assertEqual( + schema2.treeString(0).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: array (nullable = true)", + " | |-- element: struct (containsNull = true)", + " | | |-- c3: integer (nullable = true)", + " |-- c4: struct (nullable = true)", + " | |-- c5: integer (nullable = true)", + " | |-- c6: array (nullable = true)", + " | | |-- element: array (containsNull = true)", + " | | | |-- element: integer (containsNull = true)", + "", + ], + ) + self.assertEqual( + schema2.treeString(1).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: array (nullable = true)", + " |-- c4: struct (nullable = true)", + "", + ], + ) + self.assertEqual( + schema2.treeString(2).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: array (nullable = true)", + " | |-- element: struct (containsNull = true)", + " |-- c4: struct (nullable = true)", + " | |-- c5: integer (nullable = true)", + " | |-- c6: 
array (nullable = true)", + "", + ], + ) + self.assertEqual( + schema2.treeString(3).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: array (nullable = true)", + " | |-- element: struct (containsNull = true)", + " | | |-- c3: integer (nullable = true)", + " |-- c4: struct (nullable = true)", + " | |-- c5: integer (nullable = true)", + " | |-- c6: array (nullable = true)", + " | | |-- element: array (containsNull = true)", + "", + ], + ) + self.assertEqual( + schema2.treeString(4).split("\n"), + [ + "root", + " |-- c1: integer (nullable = true)", + " |-- c2: array (nullable = true)", + " | |-- element: struct (containsNull = true)", + " | | |-- c3: integer (nullable = true)", + " |-- c4: struct (nullable = true)", + " | |-- c5: integer (nullable = true)", + " | |-- c6: array (nullable = true)", + " | | |-- element: array (containsNull = true)", + " | | | |-- element: integer (containsNull = true)", + "", + ], + ) + + schema3 = DataType.fromDDL( + "c1 MAP<INT, STRUCT<c2: MAP<INT, INT>>>, c3 STRUCT<c4: MAP<INT, MAP<INT, INT>>>" + ) + self.assertEqual( + schema3.treeString(0).split("\n"), + [ + "root", + " |-- c1: map (nullable = true)", + " | |-- key: integer", + " | |-- value: struct (valueContainsNull = true)", + " | | |-- c2: map (nullable = true)", + " | | | |-- key: integer", + " | | | |-- value: integer (valueContainsNull = true)", + " |-- c3: struct (nullable = true)", + " | |-- c4: map (nullable = true)", + " | | |-- key: integer", + " | | |-- value: map (valueContainsNull = true)", + " | | | |-- key: integer", + " | | | |-- value: integer (valueContainsNull = true)", + "", + ], + ) + self.assertEqual( + schema3.treeString(1).split("\n"), + [ + "root", + " |-- c1: map (nullable = true)", + " |-- c3: struct (nullable = true)", + "", + ], + ) + self.assertEqual( + schema3.treeString(2).split("\n"), + [ + "root", + " |-- c1: map (nullable = true)", + " | |-- key: integer", + " | |-- value: struct (valueContainsNull = true)", + " |-- 
c3: struct (nullable = true)", + " | |-- c4: map (nullable = true)", + "", + ], + ) + self.assertEqual( + schema3.treeString(3).split("\n"), + [ + "root", + " |-- c1: map (nullable = true)", + " | |-- key: integer", + " | |-- value: struct (valueContainsNull = true)", + " | | |-- c2: map (nullable = true)", + " |-- c3: struct (nullable = true)", + " | |-- c4: map (nullable = true)", + " | | |-- key: integer", + " | | |-- value: map (valueContainsNull = true)", + "", + ], + ) + self.assertEqual( + schema3.treeString(4).split("\n"), + [ + "root", + " |-- c1: map (nullable = true)", + " | |-- key: integer", + " | |-- value: struct (valueContainsNull = true)", + " | | |-- c2: map (nullable = true)", + " | | | |-- key: integer", + " | | | |-- value: integer (valueContainsNull = true)", + " |-- c3: struct (nullable = true)", + " | |-- c4: map (nullable = true)", + " | | |-- key: integer", + " | | |-- value: map (valueContainsNull = true)", + " | | | |-- key: integer", + " | | | |-- value: integer (valueContainsNull = true)", + "", + ], + ) + def test_metadata_null(self): schema = StructType( [ diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 3f1e9ee83f10..fa98d09a9af9 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -47,7 +47,12 @@ from typing import ( from pyspark.util import is_remote_only from pyspark.serializers import CloudPickleSerializer -from pyspark.sql.utils import has_numpy, get_active_spark_context +from pyspark.sql.utils import ( + has_numpy, + get_active_spark_context, + escape_meta_characters, + StringConcat, +) from pyspark.sql.variant_utils import VariantUtils from pyspark.errors import ( PySparkNotImplementedError, @@ -99,6 +104,8 @@ __all__ = [ "VariantVal", ] +_JVM_INT_MAX: int = (1 << 31) - 1 + class DataType: """Base class for data types.""" @@ -199,6 +206,17 @@ class DataType: assert len(schema) == 1 return schema[0].dataType + @classmethod + def _data_type_build_formatted_string( + cls, + 
dataType: "DataType", + prefix: str, + stringConcat: StringConcat, + maxDepth: int, + ) -> None: + if isinstance(dataType, (ArrayType, StructType, MapType)): + dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1) + # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle @@ -734,6 +752,21 @@ class ArrayType(DataType): return obj return obj and [self.elementType.fromInternal(v) for v in obj] + def _build_formatted_string( + self, + prefix: str, + stringConcat: StringConcat, + maxDepth: int = _JVM_INT_MAX, + ) -> None: + if maxDepth > 0: + stringConcat.append( + f"{prefix}-- element: {self.elementType.typeName()} " + + f"(containsNull = {str(self.containsNull).lower()})\n" + ) + DataType._data_type_build_formatted_string( + self.elementType, f"{prefix} |", stringConcat, maxDepth + ) + class MapType(DataType): """Map data type. @@ -868,6 +901,25 @@ class MapType(DataType): (self.keyType.fromInternal(k), self.valueType.fromInternal(v)) for k, v in obj.items() ) + def _build_formatted_string( + self, + prefix: str, + stringConcat: StringConcat, + maxDepth: int = _JVM_INT_MAX, + ) -> None: + if maxDepth > 0: + stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n") + DataType._data_type_build_formatted_string( + self.keyType, f"{prefix} |", stringConcat, maxDepth + ) + stringConcat.append( + f"{prefix}-- value: {self.valueType.typeName()} " + + f"(valueContainsNull = {str(self.valueContainsNull).lower()})\n" + ) + DataType._data_type_build_formatted_string( + self.valueType, f"{prefix} |", stringConcat, maxDepth + ) + class StructField(DataType): """A field in :class:`StructType`. 
@@ -1016,6 +1068,21 @@ class StructField(DataType): message_parameters={}, ) + def _build_formatted_string( + self, + prefix: str, + stringConcat: StringConcat, + maxDepth: int = _JVM_INT_MAX, + ) -> None: + if maxDepth > 0: + stringConcat.append( + f"{prefix}-- {escape_meta_characters(self.name)}: {self.dataType.typeName()} " + + f"(nullable = {str(self.nullable).lower()})\n" + ) + DataType._data_type_build_formatted_string( + self.dataType, f"{prefix} |", stringConcat, maxDepth + ) + class StructType(DataType): """Struct type, consisting of a list of :class:`StructField`. @@ -1436,6 +1503,24 @@ class StructType(DataType): values = obj return _create_row(self.names, values) + def _build_formatted_string( + self, + prefix: str, + stringConcat: StringConcat, + maxDepth: int = _JVM_INT_MAX, + ) -> None: + for field in self.fields: + field._build_formatted_string(prefix, stringConcat, maxDepth) + + def treeString(self, maxDepth: int = _JVM_INT_MAX) -> str: + stringConcat = StringConcat() + stringConcat.append("root\n") + prefix = " |" + depth = maxDepth if maxDepth > 0 else _JVM_INT_MAX + for field in self.fields: + field._build_formatted_string(prefix, stringConcat, depth) + return stringConcat.toString() + class VariantType(AtomicType): """ diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index de7d58361c04..171f92e557a1 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -17,7 +17,18 @@ import inspect import functools import os -from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union, Type +from typing import ( + Any, + Callable, + Optional, + List, + Sequence, + TYPE_CHECKING, + cast, + TypeVar, + Union, + Type, +) # For backward compatibility. 
from pyspark.errors import ( # noqa: F401 @@ -124,6 +135,47 @@ class ForeachBatchFunction: implements = ["org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction"] +# Python implementation of 'org.apache.spark.sql.catalyst.util.StringConcat' +_MAX_ROUNDED_ARRAY_LENGTH = (1 << 31) - 1 - 15 + + +class StringConcat: + def __init__(self, maxLength: int = _MAX_ROUNDED_ARRAY_LENGTH): + self.maxLength: int = maxLength + self.strings: List[str] = [] + self.length: int = 0 + + def atLimit(self) -> bool: + return self.length >= self.maxLength + + def append(self, s: str) -> None: + if s is not None: + sLen = len(s) + if not self.atLimit(): + available = self.maxLength - self.length + stringToAppend = s if available >= sLen else s[0:available] + self.strings.append(stringToAppend) + + self.length = min(self.length + sLen, _MAX_ROUNDED_ARRAY_LENGTH) + + def toString(self) -> str: + # finalLength = self.maxLength if self.atLimit() else self.length + return "".join(self.strings) + + +# Python implementation of 'org.apache.spark.util.SparkSchemaUtils.escapeMetaCharacters' +def escape_meta_characters(s: str) -> str: + return ( + s.replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + .replace("\f", "\\f") + .replace("\b", "\\b") + .replace("\u000B", "\\v") + .replace("\u0007", "\\a") + ) + + def to_str(value: Any) -> Optional[str]: """ A wrapper over str(), but converts bool values to lower case strings. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org