This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 70902126ff0 [SPARK-42048][PYTHON][CONNECT] Fix the alias name for 
numpy literals
70902126ff0 is described below

commit 70902126ff09cbd6a7f4ce9e0f1c602a96936753
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Mon Feb 20 09:08:13 2023 +0900

    [SPARK-42048][PYTHON][CONNECT] Fix the alias name for numpy literals
    
    ### What changes were proposed in this pull request?
    
    Fixes the alias name for numpy literals.
    
    Also fixes `F.lit` in Spark Connect to support `np.bool_` objects.
    
    ### Why are the changes needed?
    
    Currently, the alias name for literals created from NumPy scalars contains
something like `CAST(` ... `AS <type>)`. The cast should be omitted from the
alias so that it is only the value string, the same as for literals created
from Python numbers.
    
    ### Does this PR introduce _any_ user-facing change?
    
    The alias name will be changed.
    
    ### How was this patch tested?
    
    Modified/enabled related tests.
    
    Closes #40076 from ueshin/issues/SPARK-42048/lit.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit db9dbd90d8edff222636bebf25df2fb96adef534)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/expressions.py                 |  2 ++
 python/pyspark/sql/functions.py                           |  2 +-
 python/pyspark/sql/tests/connect/test_parity_functions.py |  5 -----
 python/pyspark/sql/tests/test_functions.py                | 15 ++++++++-------
 4 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/connect/expressions.py 
b/python/pyspark/sql/connect/expressions.py
index 876748d06d8..76e4252dce7 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -281,6 +281,8 @@ class LiteralExpression(Expression):
                 dt = _from_numpy_type(value.dtype)
                 if dt is not None:
                     return dt
+                elif isinstance(value, np.bool_):
+                    return BooleanType()
             raise TypeError(f"Unsupported Data Type {type(value).__name__}")
 
     @classmethod
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b103af72e36..d296075fb0b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -181,7 +181,7 @@ def lit(col: Any) -> Column:
         if has_numpy and isinstance(col, np.generic):
             dt = _from_numpy_type(col.dtype)
             if dt is not None:
-                return _invoke_function("lit", col).astype(dt)
+                return _invoke_function("lit", col).astype(dt).alias(str(col))
         return _invoke_function("lit", col)
 
 
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py 
b/python/pyspark/sql/tests/connect/test_parity_functions.py
index 1ea33d2e370..a69e47effe4 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -48,11 +48,6 @@ class FunctionsParityTests(FunctionsTestsMixin, 
ReusedConnectTestCase):
     def test_lit_list(self):
         super().test_lit_list()
 
-    # TODO(SPARK-41283): Different column names of `lit(np.int8(1))`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_lit_np_scalar(self):
-        super().test_lit_np_scalar()
-
     def test_raise_error(self):
         self.check_raise_error(SparkConnectException)
 
diff --git a/python/pyspark/sql/tests/test_functions.py 
b/python/pyspark/sql/tests/test_functions.py
index d8343b4fb47..8bc2b96cc51 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1184,16 +1184,17 @@ class FunctionsTestsMixin:
         from pyspark.sql.functions import lit
 
         dtype_to_spark_dtypes = [
-            (np.int8, [("CAST(1 AS TINYINT)", "tinyint")]),
-            (np.int16, [("CAST(1 AS SMALLINT)", "smallint")]),
-            (np.int32, [("CAST(1 AS INT)", "int")]),
-            (np.int64, [("CAST(1 AS BIGINT)", "bigint")]),
-            (np.float32, [("CAST(1.0 AS FLOAT)", "float")]),
-            (np.float64, [("CAST(1.0 AS DOUBLE)", "double")]),
+            (np.int8, [("1", "tinyint")]),
+            (np.int16, [("1", "smallint")]),
+            (np.int32, [("1", "int")]),
+            (np.int64, [("1", "bigint")]),
+            (np.float32, [("1.0", "float")]),
+            (np.float64, [("1.0", "double")]),
             (np.bool_, [("true", "boolean")]),
         ]
         for dtype, spark_dtypes in dtype_to_spark_dtypes:
-            self.assertEqual(self.spark.range(1).select(lit(dtype(1))).dtypes, 
spark_dtypes)
+            with self.subTest(dtype):
+                
self.assertEqual(self.spark.range(1).select(lit(dtype(1))).dtypes, spark_dtypes)
 
     @unittest.skipIf(not have_numpy, "NumPy not installed")
     def test_np_scalar_input(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to