This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e55875b0bbe0 [SPARK-48372][SPARK-45716][PYTHON] Implement `StructType.treeString`
e55875b0bbe0 is described below

commit e55875b0bbe08c435ffcb0ea034ceb95938d8729
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 22 15:31:27 2024 +0800

    [SPARK-48372][SPARK-45716][PYTHON] Implement `StructType.treeString`
    
    ### What changes were proposed in this pull request?
    Implement `StructType.treeString`
    
    ### Why are the changes needed?
    Feature parity: this method was previously available only in Scala.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. A new public method `StructType.treeString` is added, for example:
    
    ```
    In [2]: schema1 = DataType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
    
    In [3]: print(schema1.treeString())
    root
     |-- c1: integer (nullable = true)
     |-- c2: struct (nullable = true)
     |    |-- c3: integer (nullable = true)
     |    |-- c4: struct (nullable = true)
     |    |    |-- c5: integer (nullable = true)
     |    |    |-- c6: integer (nullable = true)
    ```
    
    ### How was this patch tested?
    Added unit tests (`test_tree_string` in `python/pyspark/sql/tests/test_types.py`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #46685 from zhengruifeng/py_tree_string.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/tests/test_types.py | 241 +++++++++++++++++++++++++++++++++
 python/pyspark/sql/types.py            |  87 +++++++++++-
 python/pyspark/sql/utils.py            |  54 +++++++-
 3 files changed, 380 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 4d6fc499b70b..ec07406b1191 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -1170,6 +1170,247 @@ class TypesTestsMixin:
         )
         self.assertEqual(VariantType(), _parse_datatype_string("variant"))
 
+    def test_tree_string(self):
+        schema1 = DataType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+
+        self.assertEqual(
+            schema1.treeString().split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: struct (nullable = true)",
+                " |    |-- c3: integer (nullable = true)",
+                " |    |-- c4: struct (nullable = true)",
+                " |    |    |-- c5: integer (nullable = true)",
+                " |    |    |-- c6: integer (nullable = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema1.treeString(-1).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: struct (nullable = true)",
+                " |    |-- c3: integer (nullable = true)",
+                " |    |-- c4: struct (nullable = true)",
+                " |    |    |-- c5: integer (nullable = true)",
+                " |    |    |-- c6: integer (nullable = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema1.treeString(0).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: struct (nullable = true)",
+                " |    |-- c3: integer (nullable = true)",
+                " |    |-- c4: struct (nullable = true)",
+                " |    |    |-- c5: integer (nullable = true)",
+                " |    |    |-- c6: integer (nullable = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema1.treeString(1).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: struct (nullable = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema1.treeString(2).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: struct (nullable = true)",
+                " |    |-- c3: integer (nullable = true)",
+                " |    |-- c4: struct (nullable = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema1.treeString(3).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: struct (nullable = true)",
+                " |    |-- c3: integer (nullable = true)",
+                " |    |-- c4: struct (nullable = true)",
+                " |    |    |-- c5: integer (nullable = true)",
+                " |    |    |-- c6: integer (nullable = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema1.treeString(4).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: struct (nullable = true)",
+                " |    |-- c3: integer (nullable = true)",
+                " |    |-- c4: struct (nullable = true)",
+                " |    |    |-- c5: integer (nullable = true)",
+                " |    |    |-- c6: integer (nullable = true)",
+                "",
+            ],
+        )
+
+        schema2 = DataType.fromDDL(
+            "c1 INT, c2 ARRAY<STRUCT<c3: INT>>, c4 STRUCT<c5: INT, c6: ARRAY<ARRAY<INT>>>"
+        )
+        self.assertEqual(
+            schema2.treeString(0).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: array (nullable = true)",
+                " |    |-- element: struct (containsNull = true)",
+                " |    |    |-- c3: integer (nullable = true)",
+                " |-- c4: struct (nullable = true)",
+                " |    |-- c5: integer (nullable = true)",
+                " |    |-- c6: array (nullable = true)",
+                " |    |    |-- element: array (containsNull = true)",
+                " |    |    |    |-- element: integer (containsNull = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema2.treeString(1).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: array (nullable = true)",
+                " |-- c4: struct (nullable = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema2.treeString(2).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: array (nullable = true)",
+                " |    |-- element: struct (containsNull = true)",
+                " |-- c4: struct (nullable = true)",
+                " |    |-- c5: integer (nullable = true)",
+                " |    |-- c6: array (nullable = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema2.treeString(3).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: array (nullable = true)",
+                " |    |-- element: struct (containsNull = true)",
+                " |    |    |-- c3: integer (nullable = true)",
+                " |-- c4: struct (nullable = true)",
+                " |    |-- c5: integer (nullable = true)",
+                " |    |-- c6: array (nullable = true)",
+                " |    |    |-- element: array (containsNull = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema2.treeString(4).split("\n"),
+            [
+                "root",
+                " |-- c1: integer (nullable = true)",
+                " |-- c2: array (nullable = true)",
+                " |    |-- element: struct (containsNull = true)",
+                " |    |    |-- c3: integer (nullable = true)",
+                " |-- c4: struct (nullable = true)",
+                " |    |-- c5: integer (nullable = true)",
+                " |    |-- c6: array (nullable = true)",
+                " |    |    |-- element: array (containsNull = true)",
+                " |    |    |    |-- element: integer (containsNull = true)",
+                "",
+            ],
+        )
+
+        schema3 = DataType.fromDDL(
+            "c1 MAP<INT, STRUCT<c2: MAP<INT, INT>>>, c3 STRUCT<c4: MAP<INT, MAP<INT, INT>>>"
+        )
+        self.assertEqual(
+            schema3.treeString(0).split("\n"),
+            [
+                "root",
+                " |-- c1: map (nullable = true)",
+                " |    |-- key: integer",
+                " |    |-- value: struct (valueContainsNull = true)",
+                " |    |    |-- c2: map (nullable = true)",
+                " |    |    |    |-- key: integer",
+                " |    |    |    |-- value: integer (valueContainsNull = true)",
+                " |-- c3: struct (nullable = true)",
+                " |    |-- c4: map (nullable = true)",
+                " |    |    |-- key: integer",
+                " |    |    |-- value: map (valueContainsNull = true)",
+                " |    |    |    |-- key: integer",
+                " |    |    |    |-- value: integer (valueContainsNull = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema3.treeString(1).split("\n"),
+            [
+                "root",
+                " |-- c1: map (nullable = true)",
+                " |-- c3: struct (nullable = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema3.treeString(2).split("\n"),
+            [
+                "root",
+                " |-- c1: map (nullable = true)",
+                " |    |-- key: integer",
+                " |    |-- value: struct (valueContainsNull = true)",
+                " |-- c3: struct (nullable = true)",
+                " |    |-- c4: map (nullable = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema3.treeString(3).split("\n"),
+            [
+                "root",
+                " |-- c1: map (nullable = true)",
+                " |    |-- key: integer",
+                " |    |-- value: struct (valueContainsNull = true)",
+                " |    |    |-- c2: map (nullable = true)",
+                " |-- c3: struct (nullable = true)",
+                " |    |-- c4: map (nullable = true)",
+                " |    |    |-- key: integer",
+                " |    |    |-- value: map (valueContainsNull = true)",
+                "",
+            ],
+        )
+        self.assertEqual(
+            schema3.treeString(4).split("\n"),
+            [
+                "root",
+                " |-- c1: map (nullable = true)",
+                " |    |-- key: integer",
+                " |    |-- value: struct (valueContainsNull = true)",
+                " |    |    |-- c2: map (nullable = true)",
+                " |    |    |    |-- key: integer",
+                " |    |    |    |-- value: integer (valueContainsNull = true)",
+                " |-- c3: struct (nullable = true)",
+                " |    |-- c4: map (nullable = true)",
+                " |    |    |-- key: integer",
+                " |    |    |-- value: map (valueContainsNull = true)",
+                " |    |    |    |-- key: integer",
+                " |    |    |    |-- value: integer (valueContainsNull = true)",
+                "",
+            ],
+        )
+
     def test_metadata_null(self):
         schema = StructType(
             [
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 3f1e9ee83f10..fa98d09a9af9 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -47,7 +47,12 @@ from typing import (
 
 from pyspark.util import is_remote_only
 from pyspark.serializers import CloudPickleSerializer
-from pyspark.sql.utils import has_numpy, get_active_spark_context
+from pyspark.sql.utils import (
+    has_numpy,
+    get_active_spark_context,
+    escape_meta_characters,
+    StringConcat,
+)
 from pyspark.sql.variant_utils import VariantUtils
 from pyspark.errors import (
     PySparkNotImplementedError,
@@ -99,6 +104,8 @@ __all__ = [
     "VariantVal",
 ]
 
+_JVM_INT_MAX: int = (1 << 31) - 1
+
 
 class DataType:
     """Base class for data types."""
@@ -199,6 +206,17 @@ class DataType:
         assert len(schema) == 1
         return schema[0].dataType
 
+    @classmethod
+    def _data_type_build_formatted_string(
+        cls,
+        dataType: "DataType",
+        prefix: str,
+        stringConcat: StringConcat,
+        maxDepth: int,
+    ) -> None:
+        if isinstance(dataType, (ArrayType, StructType, MapType)):
+            dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1)
+
 
 # This singleton pattern does not work with pickle, you will get
 # another object after pickle and unpickle
@@ -734,6 +752,21 @@ class ArrayType(DataType):
             return obj
         return obj and [self.elementType.fromInternal(v) for v in obj]
 
+    def _build_formatted_string(
+        self,
+        prefix: str,
+        stringConcat: StringConcat,
+        maxDepth: int = _JVM_INT_MAX,
+    ) -> None:
+        if maxDepth > 0:
+            stringConcat.append(
+                f"{prefix}-- element: {self.elementType.typeName()} "
+                + f"(containsNull = {str(self.containsNull).lower()})\n"
+            )
+            DataType._data_type_build_formatted_string(
+                self.elementType, f"{prefix}    |", stringConcat, maxDepth
+            )
+
 
 class MapType(DataType):
     """Map data type.
@@ -868,6 +901,25 @@ class MapType(DataType):
             (self.keyType.fromInternal(k), self.valueType.fromInternal(v)) for k, v in obj.items()
         )
 
+    def _build_formatted_string(
+        self,
+        prefix: str,
+        stringConcat: StringConcat,
+        maxDepth: int = _JVM_INT_MAX,
+    ) -> None:
+        if maxDepth > 0:
+            stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n")
+            DataType._data_type_build_formatted_string(
+                self.keyType, f"{prefix}    |", stringConcat, maxDepth
+            )
+            stringConcat.append(
+                f"{prefix}-- value: {self.valueType.typeName()} "
+                + f"(valueContainsNull = {str(self.valueContainsNull).lower()})\n"
+            )
+            DataType._data_type_build_formatted_string(
+                self.valueType, f"{prefix}    |", stringConcat, maxDepth
+            )
+
 
 class StructField(DataType):
     """A field in :class:`StructType`.
@@ -1016,6 +1068,21 @@ class StructField(DataType):
             message_parameters={},
         )
 
+    def _build_formatted_string(
+        self,
+        prefix: str,
+        stringConcat: StringConcat,
+        maxDepth: int = _JVM_INT_MAX,
+    ) -> None:
+        if maxDepth > 0:
+            stringConcat.append(
+                f"{prefix}-- {escape_meta_characters(self.name)}: {self.dataType.typeName()} "
+                + f"(nullable = {str(self.nullable).lower()})\n"
+            )
+            DataType._data_type_build_formatted_string(
+                self.dataType, f"{prefix}    |", stringConcat, maxDepth
+            )
+
 
 class StructType(DataType):
     """Struct type, consisting of a list of :class:`StructField`.
@@ -1436,6 +1503,24 @@ class StructType(DataType):
             values = obj
         return _create_row(self.names, values)
 
+    def _build_formatted_string(
+        self,
+        prefix: str,
+        stringConcat: StringConcat,
+        maxDepth: int = _JVM_INT_MAX,
+    ) -> None:
+        for field in self.fields:
+            field._build_formatted_string(prefix, stringConcat, maxDepth)
+
+    def treeString(self, maxDepth: int = _JVM_INT_MAX) -> str:
+        stringConcat = StringConcat()
+        stringConcat.append("root\n")
+        prefix = " |"
+        depth = maxDepth if maxDepth > 0 else _JVM_INT_MAX
+        for field in self.fields:
+            field._build_formatted_string(prefix, stringConcat, depth)
+        return stringConcat.toString()
+
 
 class VariantType(AtomicType):
     """
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index de7d58361c04..171f92e557a1 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -17,7 +17,18 @@
 import inspect
 import functools
 import os
-from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union, Type
+from typing import (
+    Any,
+    Callable,
+    Optional,
+    List,
+    Sequence,
+    TYPE_CHECKING,
+    cast,
+    TypeVar,
+    Union,
+    Type,
+)
 
 # For backward compatibility.
 from pyspark.errors import (  # noqa: F401
@@ -124,6 +135,47 @@ class ForeachBatchFunction:
         implements = ["org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction"]
 
 
+# Python implementation of 'org.apache.spark.sql.catalyst.util.StringConcat'
+_MAX_ROUNDED_ARRAY_LENGTH = (1 << 31) - 1 - 15
+
+
+class StringConcat:
+    def __init__(self, maxLength: int = _MAX_ROUNDED_ARRAY_LENGTH):
+        self.maxLength: int = maxLength
+        self.strings: List[str] = []
+        self.length: int = 0
+
+    def atLimit(self) -> bool:
+        return self.length >= self.maxLength
+
+    def append(self, s: str) -> None:
+        if s is not None:
+            sLen = len(s)
+            if not self.atLimit():
+                available = self.maxLength - self.length
+                stringToAppend = s if available >= sLen else s[0:available]
+                self.strings.append(stringToAppend)
+
+            self.length = min(self.length + sLen, _MAX_ROUNDED_ARRAY_LENGTH)
+
+    def toString(self) -> str:
+        # finalLength = self.maxLength if self.atLimit()  else self.length
+        return "".join(self.strings)
+
+
+# Python implementation of 'org.apache.spark.util.SparkSchemaUtils.escapeMetaCharacters'
+def escape_meta_characters(s: str) -> str:
+    return (
+        s.replace("\n", "\\n")
+        .replace("\r", "\\r")
+        .replace("\t", "\\t")
+        .replace("\f", "\\f")
+        .replace("\b", "\\b")
+        .replace("\u000B", "\\v")
+        .replace("\u0007", "\\a")
+    )
+
+
 def to_str(value: Any) -> Optional[str]:
     """
     A wrapper over str(), but converts bool values to lower case strings.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to