This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 14d3f447360b [SPARK-48395][PYTHON] Fix `StructType.treeString` for parameterized types
14d3f447360b is described below

commit 14d3f447360b66663c8979a8cdb4c40c480a1e04
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu May 23 16:12:38 2024 +0800

    [SPARK-48395][PYTHON] Fix `StructType.treeString` for parameterized types
    
    ### What changes were proposed in this pull request?
    This PR is a follow-up of https://github.com/apache/spark/pull/46685.
    
    ### Why are the changes needed?
    `StructType.treeString` uses `DataType.typeName` to generate the tree string. However, in Python `typeName` is a class method, so for parameterized types it cannot return the same string as the Scala side.
    
    ```
    In [2]: schema = StructType().add("c", CharType(10), True).add("v", VarcharType(10), True).add("d", DecimalType(10, 2), True).add("ym00", YearMonthIntervalType(0, 0)).add("ym01", YearMonthIntervalType(0, 1)).add("ym11", YearMonthIntervalType(1, 1))
    
    In [3]: print(schema.treeString())
    root
     |-- c: char (nullable = true)
     |-- v: varchar (nullable = true)
     |-- d: decimal (nullable = true)
     |-- ym00: yearmonthinterval (nullable = true)
     |-- ym01: yearmonthinterval (nullable = true)
     |-- ym11: yearmonthinterval (nullable = true)
    ```
    
    It should be:
    ```
    In [4]: print(schema.treeString())
    root
     |-- c: char(10) (nullable = true)
     |-- v: varchar(10) (nullable = true)
     |-- d: decimal(10,2) (nullable = true)
     |-- ym00: interval year (nullable = true)
     |-- ym01: interval year to month (nullable = true)
     |-- ym11: interval month (nullable = true)
    ```
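
    For illustration, the mismatch can be reproduced directly. A minimal
    sketch (`typeName` is a class method and only sees the class, while
    the instance method `simpleString` renders the parameters):

    ```
    from pyspark.sql.types import CharType, DecimalType

    c = CharType(10)
    d = DecimalType(10, 2)

    # typeName() is a class method, so the instance parameters are lost.
    print(c.typeName())      # char
    print(d.typeName())      # decimal

    # simpleString() is an instance method and renders the parameters.
    print(c.simpleString())  # char(10)
    print(d.simpleString())  # decimal(10,2)
    ```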
    
    ### Does this PR introduce _any_ user-facing change?
    No, this feature was just added and has not been released yet.
    
    ### How was this patch tested?
    Added tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46711 from zhengruifeng/tree_string_fix.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/tests/test_types.py | 67 ++++++++++++++++++++++++++++++++++
 python/pyspark/sql/types.py            | 27 ++++++++++++--
 2 files changed, 90 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index ec07406b1191..6c64a9471363 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -41,6 +41,7 @@ from pyspark.sql.types import (
     FloatType,
     DateType,
     TimestampType,
+    TimestampNTZType,
     DayTimeIntervalType,
     YearMonthIntervalType,
     CalendarIntervalType,
@@ -1411,6 +1412,72 @@ class TypesTestsMixin:
             ],
         )
 
+    def test_tree_string_for_builtin_types(self):
+        schema = (
+            StructType()
+            .add("n", NullType())
+            .add("str", StringType())
+            .add("c", CharType(10))
+            .add("v", VarcharType(10))
+            .add("bin", BinaryType())
+            .add("bool", BooleanType())
+            .add("date", DateType())
+            .add("ts", TimestampType())
+            .add("ts_ntz", TimestampNTZType())
+            .add("dec", DecimalType(10, 2))
+            .add("double", DoubleType())
+            .add("float", FloatType())
+            .add("long", LongType())
+            .add("int", IntegerType())
+            .add("short", ShortType())
+            .add("byte", ByteType())
+            .add("ym_interval_1", YearMonthIntervalType())
+            .add("ym_interval_2", 
YearMonthIntervalType(YearMonthIntervalType.YEAR))
+            .add(
+                "ym_interval_3",
+                YearMonthIntervalType(YearMonthIntervalType.YEAR, YearMonthIntervalType.MONTH),
+            )
+            .add("dt_interval_1", DayTimeIntervalType())
+            .add("dt_interval_2", DayTimeIntervalType(DayTimeIntervalType.DAY))
+            .add(
+                "dt_interval_3",
+                DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND),
+            )
+            .add("cal_interval", CalendarIntervalType())
+            .add("var", VariantType())
+        )
+        self.assertEqual(
+            schema.treeString().split("\n"),
+            [
+                "root",
+                " |-- n: void (nullable = true)",
+                " |-- str: string (nullable = true)",
+                " |-- c: char(10) (nullable = true)",
+                " |-- v: varchar(10) (nullable = true)",
+                " |-- bin: binary (nullable = true)",
+                " |-- bool: boolean (nullable = true)",
+                " |-- date: date (nullable = true)",
+                " |-- ts: timestamp (nullable = true)",
+                " |-- ts_ntz: timestamp_ntz (nullable = true)",
+                " |-- dec: decimal(10,2) (nullable = true)",
+                " |-- double: double (nullable = true)",
+                " |-- float: float (nullable = true)",
+                " |-- long: long (nullable = true)",
+                " |-- int: integer (nullable = true)",
+                " |-- short: short (nullable = true)",
+                " |-- byte: byte (nullable = true)",
+                " |-- ym_interval_1: interval year to month (nullable = true)",
+                " |-- ym_interval_2: interval year (nullable = true)",
+                " |-- ym_interval_3: interval year to month (nullable = true)",
+                " |-- dt_interval_1: interval day to second (nullable = true)",
+                " |-- dt_interval_2: interval day (nullable = true)",
+                " |-- dt_interval_3: interval hour to second (nullable = 
true)",
+                " |-- cal_interval: interval (nullable = true)",
+                " |-- var: variant (nullable = true)",
+                "",
+            ],
+        )
+
     def test_metadata_null(self):
         schema = StructType(
             [
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index ee0cc9db5c44..17b019240f82 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -215,6 +215,24 @@ class DataType:
         if isinstance(dataType, (ArrayType, StructType, MapType)):
             dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1)
 
+    # The method typeName() does not always match the Scala side.
+    # This helper makes treeString() compatible with the Scala side.
+    @classmethod
+    def _get_jvm_type_name(cls, dataType: "DataType") -> str:
+        if isinstance(
+            dataType,
+            (
+                DecimalType,
+                CharType,
+                VarcharType,
+                DayTimeIntervalType,
+                YearMonthIntervalType,
+            ),
+        ):
+            return dataType.simpleString()
+        else:
+            return dataType.typeName()
+
 
 # This singleton pattern does not work with pickle, you will get
 # another object after pickle and unpickle
@@ -758,7 +776,7 @@ class ArrayType(DataType):
     ) -> None:
         if maxDepth > 0:
             stringConcat.append(
-                f"{prefix}-- element: {self.elementType.typeName()} "
+                f"{prefix}-- element: 
{DataType._get_jvm_type_name(self.elementType)} "
                 + f"(containsNull = {str(self.containsNull).lower()})\n"
             )
             DataType._data_type_build_formatted_string(
@@ -906,12 +924,12 @@ class MapType(DataType):
         maxDepth: int = JVM_INT_MAX,
     ) -> None:
         if maxDepth > 0:
-            stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n")
+            stringConcat.append(f"{prefix}-- key: 
{DataType._get_jvm_type_name(self.keyType)}\n")
             DataType._data_type_build_formatted_string(
                 self.keyType, f"{prefix}    |", stringConcat, maxDepth
             )
             stringConcat.append(
-                f"{prefix}-- value: {self.valueType.typeName()} "
+                f"{prefix}-- value: 
{DataType._get_jvm_type_name(self.valueType)} "
                 + f"(valueContainsNull = 
{str(self.valueContainsNull).lower()})\n"
             )
             DataType._data_type_build_formatted_string(
@@ -1074,7 +1092,8 @@ class StructField(DataType):
     ) -> None:
         if maxDepth > 0:
             stringConcat.append(
-                f"{prefix}-- {escape_meta_characters(self.name)}: 
{self.dataType.typeName()} "
+                f"{prefix}-- {escape_meta_characters(self.name)}: "
+                + f"{DataType._get_jvm_type_name(self.dataType)} "
                 + f"(nullable = {str(self.nullable).lower()})\n"
             )
             DataType._data_type_build_formatted_string(

