This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 70902126ff0 [SPARK-42048][PYTHON][CONNECT] Fix the alias name for numpy literals 70902126ff0 is described below commit 70902126ff09cbd6a7f4ce9e0f1c602a96936753 Author: Takuya UESHIN <ues...@databricks.com> AuthorDate: Mon Feb 20 09:08:13 2023 +0900 [SPARK-42048][PYTHON][CONNECT] Fix the alias name for numpy literals ### What changes were proposed in this pull request? Fixes the alias name for numpy literals. Also fixes `F.lit` in Spark Connect to support `np.bool_` objects. ### Why are the changes needed? Currently the alias name for literals created from numpy scalars contains something like `CAST(` ... `AS <type>)`, but it should be removed and return only the value string, the same as literals from Python numbers. ### Does this PR introduce _any_ user-facing change? The alias name will be changed. ### How was this patch tested? Modified/enabled related tests. Closes #40076 from ueshin/issues/SPARK-42048/lit. Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit db9dbd90d8edff222636bebf25df2fb96adef534) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/connect/expressions.py | 2 ++ python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/tests/connect/test_parity_functions.py | 5 ----- python/pyspark/sql/tests/test_functions.py | 15 ++++++++------- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 876748d06d8..76e4252dce7 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -281,6 +281,8 @@ class LiteralExpression(Expression): dt = _from_numpy_type(value.dtype) if dt is not None: return dt + elif isinstance(value, np.bool_): + return BooleanType() raise TypeError(f"Unsupported Data Type {type(value).__name__}") @classmethod diff --git 
a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b103af72e36..d296075fb0b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -181,7 +181,7 @@ def lit(col: Any) -> Column: if has_numpy and isinstance(col, np.generic): dt = _from_numpy_type(col.dtype) if dt is not None: - return _invoke_function("lit", col).astype(dt) + return _invoke_function("lit", col).astype(dt).alias(str(col)) return _invoke_function("lit", col) diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py index 1ea33d2e370..a69e47effe4 100644 --- a/python/pyspark/sql/tests/connect/test_parity_functions.py +++ b/python/pyspark/sql/tests/connect/test_parity_functions.py @@ -48,11 +48,6 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase): def test_lit_list(self): super().test_lit_list() - # TODO(SPARK-41283): Different column names of `lit(np.int8(1))` - @unittest.skip("Fails in Spark Connect, should enable.") - def test_lit_np_scalar(self): - super().test_lit_np_scalar() - def test_raise_error(self): self.check_raise_error(SparkConnectException) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index d8343b4fb47..8bc2b96cc51 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -1184,16 +1184,17 @@ class FunctionsTestsMixin: from pyspark.sql.functions import lit dtype_to_spark_dtypes = [ - (np.int8, [("CAST(1 AS TINYINT)", "tinyint")]), - (np.int16, [("CAST(1 AS SMALLINT)", "smallint")]), - (np.int32, [("CAST(1 AS INT)", "int")]), - (np.int64, [("CAST(1 AS BIGINT)", "bigint")]), - (np.float32, [("CAST(1.0 AS FLOAT)", "float")]), - (np.float64, [("CAST(1.0 AS DOUBLE)", "double")]), + (np.int8, [("1", "tinyint")]), + (np.int16, [("1", "smallint")]), + (np.int32, [("1", "int")]), + (np.int64, [("1", "bigint")]), + (np.float32, [("1.0", "float")]), 
+ (np.float64, [("1.0", "double")]), (np.bool_, [("true", "boolean")]), ] for dtype, spark_dtypes in dtype_to_spark_dtypes: - self.assertEqual(self.spark.range(1).select(lit(dtype(1))).dtypes, spark_dtypes) + with self.subTest(dtype): + self.assertEqual(self.spark.range(1).select(lit(dtype(1))).dtypes, spark_dtypes) @unittest.skipIf(not have_numpy, "NumPy not installed") def test_np_scalar_input(self): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org