This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 19fa431ef611 [SPARK-46300][PYTHON][CONNECT] Match minor behaviours in Column with full test coverage
19fa431ef611 is described below

commit 19fa431ef61181bd9bfe96a74f6d977b720d281e
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Dec 7 15:50:11 2023 +0900

    [SPARK-46300][PYTHON][CONNECT] Match minor behaviours in Column with full test coverage
    
    ### What changes were proposed in this pull request?
    
    This PR matches the corner-case behaviours in `Column` between Spark
    Connect and non-Spark Connect, adding unit tests that bring
    `pyspark.sql.column` to full test coverage.
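
    One such corner case, as a minimal sketch (assuming an active
    SparkSession; the snippet mirrors the new test added to
    `test_functions.py` below): both Spark Connect and non-Spark Connect
    now raise the same error eagerly, naming `startPos` as the offending
    argument, before any query runs.

        from pyspark.sql import functions as F
        from pyspark.errors import PySparkTypeError

        try:
            F.col("name").substr("", "")  # neither a Column nor an int
        except PySparkTypeError as e:
            print(e.getErrorClass())  # NOT_COLUMN_OR_INT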
    
    ### Why are the changes needed?
    
    - For feature parity.
    - To improve the test coverage.
        See
        https://app.codecov.io/gh/apache/spark/commit/1a651753f4e760643d719add3b16acd311454c76/blob/python/pyspark/sql/column.py
        for the code paths in `pyspark.sql.column` that are currently not
        being tested.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Manually ran the new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44228 from HyukjinKwon/SPARK-46300.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/column.py                       | 16 +++++++++--
 python/pyspark/sql/connect/column.py               |  2 +-
 python/pyspark/sql/connect/expressions.py          |  5 ++++
 .../sql/tests/connect/test_connect_column.py       |  2 +-
 python/pyspark/sql/tests/test_column.py            | 32 +++++++++++++++++++++-
 python/pyspark/sql/tests/test_functions.py         | 14 +++++++++-
 python/pyspark/sql/tests/test_types.py             | 12 ++++++++
 7 files changed, 76 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 9357b4842bbd..198dd9ff3e40 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -75,7 +75,7 @@ def _to_java_expr(col: "ColumnOrName") -> JavaObject:
 
 @overload
 def _to_seq(sc: SparkContext, cols: Iterable[JavaObject]) -> JavaObject:
-    pass
+    ...
 
 
 @overload
@@ -84,7 +84,7 @@ def _to_seq(
     cols: Iterable["ColumnOrName"],
     converter: Optional[Callable[["ColumnOrName"], JavaObject]],
 ) -> JavaObject:
-    pass
+    ...
 
 
 def _to_seq(
@@ -924,10 +924,20 @@ class Column:
 
         Examples
         --------
+
+        Example 1. Using integers for the input arguments.
+
         >>> df = spark.createDataFrame(
         ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
         >>> df.select(df.name.substr(1, 3).alias("col")).collect()
         [Row(col='Ali'), Row(col='Bob')]
+
+        Example 2. Using columns for the input arguments.
+
+        >>> df = spark.createDataFrame(
+        ...      [(3, 4, "Alice"), (2, 3, "Bob")], ["sidx", "eidx", "name"])
+        >>> df.select(df.name.substr(df.sidx, df.eidx).alias("col")).collect()
+        [Row(col='ice'), Row(col='ob')]
         """
         if type(startPos) != type(length):
             raise PySparkTypeError(
@@ -1199,7 +1209,7 @@ class Column:
             else:
                 return Column(getattr(self._jc, "as")(alias[0]))
         else:
-            if metadata:
+            if metadata is not None:
                 raise PySparkValueError(
                     error_class="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
                     message_parameters={"arg_name": "metadata"},
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index a6d9ca8a2ff4..13b00fd83d8b 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -256,7 +256,7 @@ class Column:
         else:
             raise PySparkTypeError(
                 error_class="NOT_COLUMN_OR_INT",
-                message_parameters={"arg_name": "length", "arg_type": 
type(length).__name__},
+                message_parameters={"arg_name": "startPos", "arg_type": 
type(length).__name__},
             )
         return Column(UnresolvedFunction("substr", [self._expr, start_expr, length_expr]))
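
The `arg_name` fix above corrects a mislabelled error on the Spark Connect
side: when the `substr` arguments are neither Columns nor ints, the error
previously blamed `length`, while non-Spark Connect reports `startPos`. A
minimal sketch, assuming an active session:

    from pyspark.sql import functions as F
    from pyspark.errors import PySparkTypeError

    try:
        F.col("name").substr(1.5, 2.5)  # floats: neither Columns nor ints
    except PySparkTypeError as e:
        print(e.getMessageParameters())  # {'arg_name': 'startPos', 'arg_type': 'float'}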
 
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 88c4f4d267b3..384422eed7d1 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -97,6 +97,11 @@ class Expression:
 
     def alias(self, *alias: str, **kwargs: Any) -> "ColumnAlias":
         metadata = kwargs.pop("metadata", None)
+        if len(alias) > 1 and metadata is not None:
+            raise PySparkValueError(
+                error_class="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
+                message_parameters={"arg_name": "metadata"},
+            )
         assert not kwargs, "Unexpected kwargs where passed: %s" % kwargs
         return ColumnAlias(self, list(alias), metadata)
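
With this validation in place, Spark Connect raises the same
ONLY_ALLOWED_FOR_SINGLE_COLUMN error as non-Spark Connect, and it does so
client-side while the alias expression is being built, before anything is
sent to the server. A sketch, assuming `spark` is a Spark Connect session:

    from pyspark.errors import PySparkValueError

    try:
        spark.range(1).id.alias("a", "b", metadata={"foo": "bar"})
    except PySparkValueError as e:
        print(e.getErrorClass())  # ONLY_ALLOWED_FOR_SINGLE_COLUMN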
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index f9a9fa95a373..be351e133841 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -155,7 +155,7 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
             exception=pe.exception,
             error_class="NOT_COLUMN_OR_INT",
             message_parameters={
-                "arg_name": "length",
+                "arg_name": "startPos",
                 "arg_type": "float",
             },
         )
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index 622c1f7b2104..e51ae69814bd 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -20,7 +20,7 @@ from itertools import chain
 from pyspark.sql import Column, Row
 from pyspark.sql import functions as sf
 from pyspark.sql.types import StructType, StructField, LongType
-from pyspark.errors import AnalysisException, PySparkTypeError
+from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -218,6 +218,36 @@ class ColumnTestsMixin:
         ).withColumn("square_value", mapping_expr[sf.col("key")])
         self.assertEqual(df.count(), 3)
 
+    def test_alias_negative(self):
+        with self.assertRaises(PySparkValueError) as pe:
+            self.spark.range(1).id.alias("a", "b", metadata={})
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
+            message_parameters={"arg_name": "metadata"},
+        )
+
+    def test_cast_negative(self):
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.range(1).id.cast(123)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_DATATYPE_OR_STR",
+            message_parameters={"arg_name": "dataType", "arg_type": "int"},
+        )
+
+    def test_over_negative(self):
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.range(1).id.over(123)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_WINDOWSPEC",
+            message_parameters={"arg_name": "window", "arg_type": "int"},
+        )
+
 
 class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 2bdcfa6085fd..2ac7ddbcba59 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -346,7 +346,7 @@ class FunctionsTestsMixin:
 
         df = self.spark.createDataFrame([["nick"]], schema=["name"])
         with self.assertRaises(PySparkTypeError) as pe:
-            df.select(F.col("name").substr(0, F.lit(1)))
+            F.col("name").substr(0, F.lit(1))
 
         self.check_error(
             exception=pe.exception,
@@ -359,6 +359,18 @@ class FunctionsTestsMixin:
             },
         )
 
+        with self.assertRaises(PySparkTypeError) as pe:
+            F.col("name").substr("", "")
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_COLUMN_OR_INT",
+            message_parameters={
+                "arg_name": "startPos",
+                "arg_type": "str",
+            },
+        )
+
         for name in string_functions:
             self.assertEqual(
                 df.select(getattr(F, name)("name")).first()[0],
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 06064e58c794..992abc8e82d9 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -883,6 +883,18 @@ class TypesTestsMixin:
         self.assertEqual("v", df.select(df.d["k"]).first()[0])
         self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
 
+        # Deprecated behaviors
+        map_col = F.create_map(F.lit(0), F.lit(100), F.lit(1), F.lit(200))
+        self.assertEqual(
+            self.spark.range(1).withColumn("mapped", 
map_col.getItem(F.col("id"))).first()[1], 100
+        )
+
+        struct_col = F.struct(F.lit(0), F.lit(100), F.lit(1), F.lit(200))
+        self.assertEqual(
+            self.spark.range(1).withColumn("struct", 
struct_col.getField(F.lit("col1"))).first()[1],
+            0,
+        )
+
     def test_infer_long_type(self):
         longrow = [Row(f1="a", f2=100000000000000)]
         df = self.sc.parallelize(longrow).toDF()

