This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e97ad0a44419 [SPARK-48278][PYTHON][CONNECT] Refine the string representation of `Cast`
e97ad0a44419 is described below

commit e97ad0a444195a6f1db551fd652225973a517571
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 15 15:56:45 2024 +0800

    [SPARK-48278][PYTHON][CONNECT] Refine the string representation of `Cast`
    
    ### What changes were proposed in this pull request?
    Refine the string representation of `Cast`
    
    ### Why are the changes needed?
    Make the string representation as consistent with Spark Classic as possible.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the string representation of `Cast` columns in Spark Connect changes, as shown below.

    Spark Classic:
    ```
    In [1]: from pyspark.sql import functions as sf
    
    In [2]: sf.col("a").try_cast("int")
    Out[2]: Column<'TRY_CAST(a AS INT)'>
    ```
    
    Spark Connect, before this PR:
    ```
    In [1]: from pyspark.sql import functions as sf
    
    In [2]: sf.col("a").try_cast("int")
    Out[2]: Column<'(a (int))'>
    ```
    
    Spark Connect, after this PR:
    ```
    In [1]: from pyspark.sql import functions as sf
    
    In [2]: sf.col("a").try_cast("int")
    Out[2]: Column<'TRY_CAST(a AS INT)'>
    ```
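    
    Note that the representation is best-effort: for some type names the two may still differ. One example, taken from the comment added in this patch (illustrative, not exhaustive):
    
    ```
    str(sf.col("a").cast("long")):
      Column<'CAST(a AS BIGINT)'>     <- Spark Classic
      Column<'CAST(a AS LONG)'>       <- Spark Connect
    ```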
    
    ### How was this patch tested?
    Added unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #46585 from zhengruifeng/cast_str_repr.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/expressions.py | 14 +++++++++++++-
 python/pyspark/sql/tests/test_column.py   | 13 ++++++++++++-
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 92dde2f3670e..4dc54793ed81 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -848,6 +848,7 @@ class CastExpression(Expression):
     ) -> None:
         super().__init__()
         self._expr = expr
+        assert isinstance(data_type, (DataType, str))
         self._data_type = data_type
         if eval_mode is not None:
             assert isinstance(eval_mode, str)
@@ -873,7 +874,18 @@ class CastExpression(Expression):
         return fun
 
     def __repr__(self) -> str:
-        return f"({self._expr} ({self._data_type}))"
+        # We cannot guarantee the string representations be exactly the same, e.g.
+        # str(sf.col("a").cast("long")):
+        #   Column<'CAST(a AS BIGINT)'>     <- Spark Classic
+        #   Column<'CAST(a AS LONG)'>       <- Spark Connect
+        if isinstance(self._data_type, DataType):
+            str_data_type = self._data_type.simpleString().upper()
+        else:
+            str_data_type = str(self._data_type).upper()
+        if self._eval_mode is not None and self._eval_mode == "try":
+            return f"TRY_CAST({self._expr} AS {str_data_type})"
+        else:
+            return f"CAST({self._expr} AS {str_data_type})"
 
 
 class UnresolvedNamedLambdaVariable(Expression):
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index 6e5fcde57cab..8f6adb37b9d4 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -19,7 +19,7 @@
 from itertools import chain
 from pyspark.sql import Column, Row
 from pyspark.sql import functions as sf
-from pyspark.sql.types import StructType, StructField, LongType
+from pyspark.sql.types import StructType, StructField, IntegerType, LongType
 from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
@@ -228,6 +228,17 @@ class ColumnTestsMixin:
             message_parameters={"arg_name": "metadata"},
         )
 
+    def test_cast_str_representation(self):
+        self.assertEqual(str(sf.col("a").cast("int")), "Column<'CAST(a AS INT)'>")
+        self.assertEqual(str(sf.col("a").cast("INT")), "Column<'CAST(a AS INT)'>")
+        self.assertEqual(str(sf.col("a").cast(IntegerType())), "Column<'CAST(a AS INT)'>")
+        self.assertEqual(str(sf.col("a").cast(LongType())), "Column<'CAST(a AS BIGINT)'>")
+
+        self.assertEqual(str(sf.col("a").try_cast("int")), "Column<'TRY_CAST(a AS INT)'>")
+        self.assertEqual(str(sf.col("a").try_cast("INT")), "Column<'TRY_CAST(a AS INT)'>")
+        self.assertEqual(str(sf.col("a").try_cast(IntegerType())), "Column<'TRY_CAST(a AS INT)'>")
+        self.assertEqual(str(sf.col("a").try_cast(LongType())), "Column<'TRY_CAST(a AS BIGINT)'>")
+
     def test_cast_negative(self):
         with self.assertRaises(PySparkTypeError) as pe:
             self.spark.range(1).id.cast(123)
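
For context on the expected values in the new test (an illustrative sketch, not part of the patch): when a `DataType` instance is passed, the new `__repr__` derives the type name via `simpleString().upper()`, which is why `LongType()` renders as `BIGINT`; plain string type names are simply uppercased.

```
# Illustrative sketch only: how the type name inside CAST(... AS <TYPE>) is derived.
from pyspark.sql.types import IntegerType, LongType

print(IntegerType().simpleString().upper())  # INT     -> Column<'CAST(a AS INT)'>
print(LongType().simpleString().upper())     # BIGINT  -> Column<'CAST(a AS BIGINT)'>
print("int".upper())                         # INT     -> Column<'CAST(a AS INT)'>
```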


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
