This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 19fa431ef611 [SPARK-46300][PYTHON][CONNECT] Match minor behaviours in Column with full test coverage
19fa431ef611 is described below

commit 19fa431ef61181bd9bfe96a74f6d977b720d281e
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Dec 7 15:50:11 2023 +0900

    [SPARK-46300][PYTHON][CONNECT] Match minor behaviours in Column with full test coverage

    ### What changes were proposed in this pull request?

    This PR matches corner-case behaviours in `Column` between Spark Connect and non-Spark Connect, and adds unit tests that bring `pyspark.sql.column` to full test coverage.

    ### Why are the changes needed?

    - For feature parity.
    - To improve the test coverage. See https://app.codecov.io/gh/apache/spark/commit/1a651753f4e760643d719add3b16acd311454c76/blob/python/pyspark/sql/column.py; parts of this module are not being tested.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Manually ran the new unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #44228 from HyukjinKwon/SPARK-46300.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/column.py                  | 16 +++++++++--
 python/pyspark/sql/connect/column.py          |  2 +-
 python/pyspark/sql/connect/expressions.py     |  5 ++++
 .../sql/tests/connect/test_connect_column.py  |  2 +-
 python/pyspark/sql/tests/test_column.py       | 32 +++++++++++++++++++++-
 python/pyspark/sql/tests/test_functions.py    | 14 +++++++++-
 python/pyspark/sql/tests/test_types.py        | 12 ++++++++
 7 files changed, 76 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 9357b4842bbd..198dd9ff3e40 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -75,7 +75,7 @@ def _to_java_expr(col: "ColumnOrName") -> JavaObject:
 
 @overload
 def _to_seq(sc: SparkContext, cols: Iterable[JavaObject]) -> JavaObject:
-    pass
+    ...
 
 
 @overload
@@ -84,7 +84,7 @@ def _to_seq(
     cols: Iterable["ColumnOrName"],
     converter: Optional[Callable[["ColumnOrName"], JavaObject]],
 ) -> JavaObject:
-    pass
+    ...
 
 
 def _to_seq(
@@ -924,10 +924,20 @@ class Column:
 
         Examples
        --------
+
+        Example 1. Using integers for the input arguments.
+
         >>> df = spark.createDataFrame(
         ...     [(2, "Alice"), (5, "Bob")], ["age", "name"])
         >>> df.select(df.name.substr(1, 3).alias("col")).collect()
         [Row(col='Ali'), Row(col='Bob')]
+
+        Example 2. Using columns for the input arguments.
+
+        >>> df = spark.createDataFrame(
+        ...     [(3, 4, "Alice"), (2, 3, "Bob")], ["sidx", "eidx", "name"])
+        >>> df.select(df.name.substr(df.sidx, df.eidx).alias("col")).collect()
+        [Row(col='ice'), Row(col='ob')]
         """
         if type(startPos) != type(length):
             raise PySparkTypeError(
@@ -1199,7 +1209,7 @@ class Column:
             else:
                 return Column(getattr(self._jc, "as")(alias[0]))
         else:
-            if metadata:
+            if metadata is not None:
                 raise PySparkValueError(
                     error_class="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
                     message_parameters={"arg_name": "metadata"},
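The `if metadata is not None:` change above is a behaviour fix, not a cleanup: an empty dict is falsy, so the old truthiness check silently accepted `metadata={}` alongside multiple aliases instead of rejecting it. A minimal sketch of the now-rejected call, assuming an active classic session obtained via `SparkSession.builder.getOrCreate()` (the session setup is illustrative, not part of the patch):

    from pyspark.sql import SparkSession
    from pyspark.errors import PySparkValueError

    # Assumed setup for illustration; any running classic session works.
    spark = SparkSession.builder.getOrCreate()

    try:
        # metadata may only accompany a single alias; an empty dict no
        # longer slips through now that the check is `is not None`.
        spark.range(1).id.alias("a", "b", metadata={})
    except PySparkValueError as e:
        print(e.getErrorClass())  # ONLY_ALLOWED_FOR_SINGLE_COLUMN

This mirrors the new `test_alias_negative` case further down, and the same guard is added on the Spark Connect side in `expressions.py` below.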
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index a6d9ca8a2ff4..13b00fd83d8b 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -256,7 +256,7 @@ class Column:
         else:
             raise PySparkTypeError(
                 error_class="NOT_COLUMN_OR_INT",
-                message_parameters={"arg_name": "length", "arg_type": type(length).__name__},
+                message_parameters={"arg_name": "startPos", "arg_type": type(length).__name__},
             )
 
         return Column(UnresolvedFunction("substr", [self._expr, start_expr, length_expr]))
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 88c4f4d267b3..384422eed7d1 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -97,6 +97,11 @@ class Expression:
 
     def alias(self, *alias: str, **kwargs: Any) -> "ColumnAlias":
         metadata = kwargs.pop("metadata", None)
+        if len(alias) > 1 and metadata is not None:
+            raise PySparkValueError(
+                error_class="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
+                message_parameters={"arg_name": "metadata"},
+            )
         assert not kwargs, "Unexpected kwargs where passed: %s" % kwargs
         return ColumnAlias(self, list(alias), metadata)
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index f9a9fa95a373..be351e133841 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -155,7 +155,7 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
             exception=pe.exception,
             error_class="NOT_COLUMN_OR_INT",
             message_parameters={
-                "arg_name": "length",
+                "arg_name": "startPos",
                 "arg_type": "float",
             },
         )
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index 622c1f7b2104..e51ae69814bd 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -20,7 +20,7 @@ from itertools import chain
 from pyspark.sql import Column, Row
 from pyspark.sql import functions as sf
 from pyspark.sql.types import StructType, StructField, LongType
-from pyspark.errors import AnalysisException, PySparkTypeError
+from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -218,6 +218,36 @@ class ColumnTestsMixin:
         ).withColumn("square_value", mapping_expr[sf.col("key")])
         self.assertEqual(df.count(), 3)
 
+    def test_alias_negative(self):
+        with self.assertRaises(PySparkValueError) as pe:
+            self.spark.range(1).id.alias("a", "b", metadata={})
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
+            message_parameters={"arg_name": "metadata"},
+        )
+
+    def test_cast_negative(self):
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.range(1).id.cast(123)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_DATATYPE_OR_STR",
+            message_parameters={"arg_name": "dataType", "arg_type": "int"},
+        )
+
+    def test_over_negative(self):
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.range(1).id.over(123)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_WINDOWSPEC",
+            message_parameters={"arg_name": "window", "arg_type": "int"},
+        )
+
 
 class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase):
     pass
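With the `arg_name` fix above, classic PySpark and Spark Connect now blame the same parameter when `substr` receives two arguments of a matching but unsupported type. A short sketch of the aligned error, again assuming an active session (classic or Connect; the setup names are illustrative, not from the patch):

    from pyspark.sql import SparkSession
    from pyspark.errors import PySparkTypeError

    # Assumed setup for illustration.
    spark = SparkSession.builder.getOrCreate()
    name_col = spark.range(1).id.cast("string")

    try:
        name_col.substr(1.5, 2.5)  # both floats: neither Column nor int
    except PySparkTypeError as e:
        print(e.getErrorClass())         # NOT_COLUMN_OR_INT
        print(e.getMessageParameters())  # {'arg_name': 'startPos', 'arg_type': 'float'}

The new test in `test_functions.py` below exercises the same path with two strings.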
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 2bdcfa6085fd..2ac7ddbcba59 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -346,7 +346,7 @@ class FunctionsTestsMixin:
         df = self.spark.createDataFrame([["nick"]], schema=["name"])
 
         with self.assertRaises(PySparkTypeError) as pe:
-            df.select(F.col("name").substr(0, F.lit(1)))
+            F.col("name").substr(0, F.lit(1))
 
         self.check_error(
             exception=pe.exception,
@@ -359,6 +359,18 @@ class FunctionsTestsMixin:
             },
         )
 
+        with self.assertRaises(PySparkTypeError) as pe:
+            F.col("name").substr("", "")
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_COLUMN_OR_INT",
+            message_parameters={
+                "arg_name": "startPos",
+                "arg_type": "str",
+            },
+        )
+
         for name in string_functions:
             self.assertEqual(
                 df.select(getattr(F, name)("name")).first()[0],
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 06064e58c794..992abc8e82d9 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -883,6 +883,18 @@ class TypesTestsMixin:
         self.assertEqual("v", df.select(df.d["k"]).first()[0])
         self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
 
+        # Deprecated behaviors
+        map_col = F.create_map(F.lit(0), F.lit(100), F.lit(1), F.lit(200))
+        self.assertEqual(
+            self.spark.range(1).withColumn("mapped", map_col.getItem(F.col("id"))).first()[1], 100
+        )
+
+        struct_col = F.struct(F.lit(0), F.lit(100), F.lit(1), F.lit(200))
+        self.assertEqual(
+            self.spark.range(1).withColumn("struct", struct_col.getField(F.lit("col1"))).first()[1],
+            0,
+        )
+
     def test_infer_long_type(self):
         longrow = [Row(f1="a", f2=100000000000000)]
         df = self.sc.parallelize(longrow).toDF()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org