This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e97ad0a44419 [SPARK-48278][PYTHON][CONNECT] Refine the string representation of `Cast`
e97ad0a44419 is described below

commit e97ad0a444195a6f1db551fd652225973a517571
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 15 15:56:45 2024 +0800

[SPARK-48278][PYTHON][CONNECT] Refine the string representation of `Cast`

### What changes were proposed in this pull request?
Refine the string representation of `Cast`.

### Why are the changes needed?
Try our best to make the string representation consistent with Spark Classic.

### Does this PR introduce _any_ user-facing change?
Spark Classic:
```
In [1]: from pyspark.sql import functions as sf

In [2]: sf.col("a").try_cast("int")
Out[2]: Column<'TRY_CAST(a AS INT)'>
```

Spark Connect, before this PR:
```
In [1]: from pyspark.sql import functions as sf

In [2]: sf.col("a").try_cast("int")
Out[2]: Column<'(a (int))'>
```

Spark Connect, after this PR:
```
In [1]: from pyspark.sql import functions as sf

In [2]: sf.col("a").try_cast("int")
Out[2]: Column<'TRY_CAST(a AS INT)'>
```

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46585 from zhengruifeng/cast_str_repr.
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/connect/expressions.py | 14 +++++++++++++- python/pyspark/sql/tests/test_column.py | 13 ++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 92dde2f3670e..4dc54793ed81 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -848,6 +848,7 @@ class CastExpression(Expression): ) -> None: super().__init__() self._expr = expr + assert isinstance(data_type, (DataType, str)) self._data_type = data_type if eval_mode is not None: assert isinstance(eval_mode, str) @@ -873,7 +874,18 @@ class CastExpression(Expression): return fun def __repr__(self) -> str: - return f"({self._expr} ({self._data_type}))" + # We cannot guarantee the string representations be exactly the same, e.g. + # str(sf.col("a").cast("long")): + # Column<'CAST(a AS BIGINT)'> <- Spark Classic + # Column<'CAST(a AS LONG)'> <- Spark Connect + if isinstance(self._data_type, DataType): + str_data_type = self._data_type.simpleString().upper() + else: + str_data_type = str(self._data_type).upper() + if self._eval_mode is not None and self._eval_mode == "try": + return f"TRY_CAST({self._expr} AS {str_data_type})" + else: + return f"CAST({self._expr} AS {str_data_type})" class UnresolvedNamedLambdaVariable(Expression): diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 6e5fcde57cab..8f6adb37b9d4 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -19,7 +19,7 @@ from itertools import chain from pyspark.sql import Column, Row from pyspark.sql import functions as sf -from pyspark.sql.types import StructType, StructField, LongType +from pyspark.sql.types import StructType, StructField, IntegerType, LongType from pyspark.errors import 
AnalysisException, PySparkTypeError, PySparkValueError from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -228,6 +228,17 @@ class ColumnTestsMixin: message_parameters={"arg_name": "metadata"}, ) + def test_cast_str_representation(self): + self.assertEqual(str(sf.col("a").cast("int")), "Column<'CAST(a AS INT)'>") + self.assertEqual(str(sf.col("a").cast("INT")), "Column<'CAST(a AS INT)'>") + self.assertEqual(str(sf.col("a").cast(IntegerType())), "Column<'CAST(a AS INT)'>") + self.assertEqual(str(sf.col("a").cast(LongType())), "Column<'CAST(a AS BIGINT)'>") + + self.assertEqual(str(sf.col("a").try_cast("int")), "Column<'TRY_CAST(a AS INT)'>") + self.assertEqual(str(sf.col("a").try_cast("INT")), "Column<'TRY_CAST(a AS INT)'>") + self.assertEqual(str(sf.col("a").try_cast(IntegerType())), "Column<'TRY_CAST(a AS INT)'>") + self.assertEqual(str(sf.col("a").try_cast(LongType())), "Column<'TRY_CAST(a AS BIGINT)'>") + def test_cast_negative(self): with self.assertRaises(PySparkTypeError) as pe: self.spark.range(1).id.cast(123) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org