This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c74f584481d9 [SPARK-48039][PYTHON][CONNECT] Update the error class for 
`group.apply`
c74f584481d9 is described below

commit c74f584481d9bcefda7e8ac2a37feb2d61891fe4
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Apr 29 20:06:22 2024 +0900

    [SPARK-48039][PYTHON][CONNECT] Update the error class for `group.apply`
    
    ### What changes were proposed in this pull request?
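    Use the dedicated error class `INVALID_UDF_EVAL_TYPE` (raised as `PySparkTypeError`) when `group.apply` is given something other than a pandas UDF of eval type `SQL_GROUPED_MAP_PANDAS_UDF`.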
    
    ### Why are the changes needed?
    
    https://github.com/apache/spark/commit/eae91ee3c96b6887581e59821d905b8ea94f6bc0
    introduced a dedicated error class `INVALID_UDF_EVAL_TYPE` for `group.apply`,
    but only used it in Spark Connect.
    
    This PR uses this error class in Spark Classic as well, so the two are consistent.
    It also enables the Spark Connect parity test `GroupedApplyInPandasTests.test_wrong_args`, which was previously skipped.
    
    ### Does this PR introduce _any_ user-facing change?
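    Yes. In Spark Classic, passing an invalid function to `group.apply` now raises `PySparkTypeError` with error class `INVALID_UDF_EVAL_TYPE`, instead of `PySparkValueError` with `INVALID_PANDAS_UDF`. This matches the behavior already present in Spark Connect.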
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #46277 from zhengruifeng/fix_test_wrong_args.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/pandas/group_ops.py              | 10 ++++------
 .../tests/connect/test_parity_pandas_grouped_map.py |  4 ----
 .../sql/tests/pandas/test_pandas_grouped_map.py     | 21 +++++++++++----------
 3 files changed, 15 insertions(+), 20 deletions(-)

diff --git a/python/pyspark/sql/pandas/group_ops.py 
b/python/pyspark/sql/pandas/group_ops.py
index d5b214e2f7d5..3d1c50d94902 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -18,7 +18,7 @@ import sys
 from typing import List, Union, TYPE_CHECKING, cast
 import warnings
 
-from pyspark.errors import PySparkValueError
+from pyspark.errors import PySparkTypeError
 from pyspark.util import PythonEvalType
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
@@ -100,11 +100,9 @@ class PandasGroupedOpsMixin:
                 != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
             )
         ):
-            raise PySparkValueError(
-                error_class="INVALID_PANDAS_UDF",
-                message_parameters={
-                    "detail": "the udf argument must be a pandas_udf of type 
GROUPED_MAP."
-                },
+            raise PySparkTypeError(
+                error_class="INVALID_UDF_EVAL_TYPE",
+                message_parameters={"eval_type": "SQL_GROUPED_MAP_PANDAS_UDF"},
             )
 
         warnings.warn(
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py 
b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py
index f0e7eeb606ca..1cc4ce012623 100644
--- a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py
@@ -30,10 +30,6 @@ class 
GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedConnectTes
     def test_wrong_return_type(self):
         super().test_wrong_return_type()
 
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_wrong_args(self):
-        super().test_wrong_args()
-
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_unsupported_types(self):
         super().test_unsupported_types()
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 0396006e2b36..f43dafc0a4a1 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -52,7 +52,7 @@ from pyspark.sql.types import (
     MapType,
     YearMonthIntervalType,
 )
-from pyspark.errors import PythonException, PySparkTypeError
+from pyspark.errors import PythonException, PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -421,22 +421,23 @@ class GroupedApplyInPandasTestsMixin:
     def check_wrong_args(self):
         df = self.data
 
-        with self.assertRaisesRegex(ValueError, "Invalid function"):
+        with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
             df.groupby("id").apply(lambda x: x)
-        with self.assertRaisesRegex(ValueError, "Invalid function"):
+        with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
             df.groupby("id").apply(udf(lambda x: x, DoubleType()))
-        with self.assertRaisesRegex(ValueError, "Invalid function"):
+        with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
             df.groupby("id").apply(sum(df.v))
-        with self.assertRaisesRegex(ValueError, "Invalid function"):
+        with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
             df.groupby("id").apply(df.v + 1)
-        with self.assertRaisesRegex(ValueError, "Invalid function"):
+        with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
+            df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType()))
+        with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
+            df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType(), 
PandasUDFType.SCALAR))
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
             df.groupby("id").apply(
                 pandas_udf(lambda: 1, StructType([StructField("d", 
DoubleType())]))
             )
-        with self.assertRaisesRegex(ValueError, "Invalid function"):
-            df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType()))
-        with self.assertRaisesRegex(ValueError, "Invalid 
function.*GROUPED_MAP"):
-            df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType(), 
PandasUDFType.SCALAR))
 
     def test_unsupported_types(self):
         with self.quiet():


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to