This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9bb15db85e53 [SPARK-48228][PYTHON][CONNECT] Implement the missing function validation in ApplyInXXX
9bb15db85e53 is described below

commit 9bb15db85e53b69b9c0ba112cd1dd93d8213eea4
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu May 9 22:01:13 2024 -0700

    [SPARK-48228][PYTHON][CONNECT] Implement the missing function validation in ApplyInXXX
    
    ### What changes were proposed in this pull request?
    Implement the missing function validation in ApplyInXXX
    
    https://github.com/apache/spark/pull/46397 fixed this issue for `Cogrouped.ApplyInPandas`; this PR fixes the remaining methods.
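    
    For illustration, a minimal sketch of the eager check this PR adds at the start of the Connect `ApplyInXXX` methods (simplified from the diff below; the real methods go on to build the `UserDefinedFunction` and return a DataFrame):
    
    ```
    from pyspark.util import PythonEvalType
    from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
    
    class GroupedData:  # simplified stand-in for pyspark.sql.connect.group.GroupedData
        def applyInPandas(self, func, schema):
            # Validate the user function before any plan is built, so a wrong
            # signature fails at call time with INVALID_PANDAS_UDF instead of
            # deep inside the Python worker at execution time.
            _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
            ...  # build the UDF and the relation as in the real method
    ```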
    
    ### Why are the changes needed?
    For a better error message:
    
    ```
    In [12]: df1 = spark.range(11)
    
    In [13]: df2 = df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
    
    In [14]: df2.show()
    ```
    
    Before this PR, an invalid function caused confusing errors at execution time:
    ```
    24/05/10 11:37:36 ERROR Executor: Exception in task 0.0 in stage 10.0 (TID 36)
    org.apache.spark.api.python.PythonException: Traceback (most recent call last):
      File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1834, in main
        process()
      File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1826, in process
        serializer.dump_stream(out_iter, outfile)
      File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 531, in dump_stream
        return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 104, in dump_stream
        for batch in iterator:
      File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 524, in init_stream_yield_batches
        for series in iterator:
      File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1610, in mapper
        return f(keys, vals)
               ^^^^^^^^^^^^^
      File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 488, in <lambda>
        return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
                              ^^^^^^^^^^^^^
      File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 483, in wrapped
        result, return_type, _assign_cols_by_name, truncate_return_schema=False
        ^^^^^^
    UnboundLocalError: cannot access local variable 'result' where it is not associated with a value

            at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:523)
            at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:117)
            at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:479)
            at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
            at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:601)
            at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
            at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
            at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
            at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50)
            at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388)
            at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:896)
    
            ...
    ```
    
    After this PR, the error is raised before execution, which is consistent with Spark Classic and much clearer:
    ```
    PySparkValueError: [INVALID_PANDAS_UDF] Invalid function: pandas_udf with function type GROUPED_MAP or the function in groupby.applyInPandas must take either one argument (data) or two arguments (key, data).
    
    ```
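    
    For reference, a simplified, hypothetical sketch of the argument-count rule behind that message (the real `_validate_pandas_udf` in pyspark does considerably more, e.g. handling type hints and other UDF kinds): the grouped-map function must accept either one argument (data) or two arguments (key, data).
    
    ```
    from inspect import getfullargspec
    
    def _check_group_apply_func(func):
        # Hypothetical helper, for illustration only: reject functions whose
        # signature cannot take (data) or (key, data), mirroring the rule in
        # the INVALID_PANDAS_UDF message above.
        argspec = getfullargspec(func)
        if not argspec.varargs and len(argspec.args) not in (1, 2):
            raise ValueError(
                "[INVALID_PANDAS_UDF] the function in groupby.applyInPandas must "
                "take either one argument (data) or two arguments (key, data)."
            )
    
    _check_group_apply_func(lambda pdf: pdf)        # ok: one (data) argument
    _check_group_apply_func(lambda key, pdf: pdf)   # ok: (key, data)
    try:
        _check_group_apply_func(lambda: 1)          # zero arguments -> rejected
    except ValueError as e:
        print(e)
    ```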
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the error message changes.
    
    ### How was this patch tested?
    Added tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #46519 from zhengruifeng/missing_check_in_group.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 python/pyspark/sql/connect/group.py                  |  8 ++++++--
 python/pyspark/sql/pandas/functions.py               |  4 ++--
 .../sql/tests/pandas/test_pandas_grouped_map.py      | 20 ++++++++++++++++++++
 3 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index c916e8acf3e4..2a5bb5939a3f 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -34,6 +34,7 @@ from typing import (
 from pyspark.util import PythonEvalType
 from pyspark.sql.group import GroupedData as PySparkGroupedData
 from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
+from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
 from pyspark.sql.types import NumericType
 from pyspark.sql.types import StructType
 
@@ -293,6 +294,7 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
 
+        _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -322,6 +324,7 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
 
+        _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE)
         udf_obj = UserDefinedFunction(
             func,
             returnType=outputStructType,
@@ -360,6 +363,7 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
 
+        _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -398,9 +402,8 @@ class PandasCogroupedOps:
     ) -> "DataFrame":
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
-        from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
 
-        _validate_pandas_udf(func, schema, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
+        _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -426,6 +429,7 @@ class PandasCogroupedOps:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
 
+        _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 5922a5ced863..020105bb064a 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -432,7 +432,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
 
 
 # validate the pandas udf and return the adjusted eval type
-def _validate_pandas_udf(f, returnType, evalType) -> int:
+def _validate_pandas_udf(f, evalType) -> int:
     argspec = getfullargspec(f)
 
     # pandas UDF by type hints.
@@ -533,7 +533,7 @@ def _validate_pandas_udf(f, returnType, evalType) -> int:
 
 
 def _create_pandas_udf(f, returnType, evalType):
-    evalType = _validate_pandas_udf(f, returnType, evalType)
+    evalType = _validate_pandas_udf(f, evalType)
 
     if is_remote():
         from pyspark.sql.connect.udf import _create_udf as _create_connect_udf
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 1e86e12eb74f..a26d6d02a2bc 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -439,6 +439,26 @@ class GroupedApplyInPandasTestsMixin:
                pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))
             )
 
+    def test_wrong_args_in_apply_func(self):
+        df1 = self.spark.range(11)
+        df2 = self.spark.range(22)
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").applyInArrow(lambda: 1, StructType([StructField("d", DoubleType())]))
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
+                lambda: 1, StructType([StructField("d", DoubleType())])
+            )
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
+                lambda: 1, StructType([StructField("d", DoubleType())])
+            )
+
     def test_unsupported_types(self):
         with self.quiet():
             self.check_unsupported_types()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
