This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9bb15db85e53 [SPARK-48228][PYTHON][CONNECT] Implement the missing function validation in ApplyInXXX

9bb15db85e53 is described below

commit 9bb15db85e53b69b9c0ba112cd1dd93d8213eea4
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu May 9 22:01:13 2024 -0700

    [SPARK-48228][PYTHON][CONNECT] Implement the missing function validation in ApplyInXXX

### What changes were proposed in this pull request?
Implement the missing function validation in ApplyInXXX.

https://github.com/apache/spark/pull/46397 fixed this issue for `Cogrouped.ApplyInPandas`; this PR fixes the remaining methods.

### Why are the changes needed?
For a better error message:

```
In [12]: df1 = spark.range(11)

In [13]: df2 = df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))

In [14]: df2.show()
```

Before this PR, the invalid function caused a confusing execution-time error:

```
24/05/10 11:37:36 ERROR Executor: Exception in task 0.0 in stage 10.0 (TID 36)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1834, in main
    process()
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1826, in process
    serializer.dump_stream(out_iter, outfile)
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 531, in dump_stream
    return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 104, in dump_stream
    for batch in iterator:
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 524, in init_stream_yield_batches
    for series in iterator:
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1610, in mapper
    return f(keys, vals)
           ^^^^^^^^^^^^^
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 488, in <lambda>
    return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
                          ^^^^^^^^^^^^^
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 483, in wrapped
    result, return_type, _assign_cols_by_name, truncate_return_schema=False
    ^^^^^^
UnboundLocalError: cannot access local variable 'result' where it is not associated with a value

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:523)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:117)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:479)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:601)
	at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:896)
	...
```

After this PR, the error is raised before execution, which is consistent with Spark Classic and much clearer:

```
PySparkValueError: [INVALID_PANDAS_UDF] Invalid function: pandas_udf with function type GROUPED_MAP or the function in groupby.applyInPandas must take either one argument (data) or two arguments (key, data).
```
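For comparison, a function with a valid signature passes the new check and runs normally. A minimal sketch (the `normalize` function and its cast-to-double logic are illustrative only, reusing the `spark` session from the repro above):

```python
import pandas as pd

from pyspark.sql.types import DoubleType, StructField, StructType

# Valid: the function takes a single (data) argument, so the check passes.
def normalize(pdf: pd.DataFrame) -> pd.DataFrame:
    # Illustrative per-group transform: cast the grouped ids to double.
    return pd.DataFrame({"d": pdf["id"].astype("float64")})

df1 = spark.range(11)
df2 = df1.groupby("id").applyInPandas(
    normalize, StructType([StructField("d", DoubleType())])
)
df2.show()  # runs instead of failing with the UnboundLocalError above
```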
### Does this PR introduce _any_ user-facing change?
Yes, the error message changes.

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46519 from zhengruifeng/missing_check_in_group.

Authored-by: Ruifeng Zheng <ruife...@apache.org>
Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 python/pyspark/sql/connect/group.py                 |  8 ++++++--
 python/pyspark/sql/pandas/functions.py              |  4 ++--
 .../sql/tests/pandas/test_pandas_grouped_map.py     | 20 ++++++++++++++++++++
 3 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index c916e8acf3e4..2a5bb5939a3f 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -34,6 +34,7 @@ from typing import (
 from pyspark.util import PythonEvalType
 from pyspark.sql.group import GroupedData as PySparkGroupedData
 from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
+from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
 from pyspark.sql.types import NumericType
 from pyspark.sql.types import StructType
 
@@ -293,6 +294,7 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
 
+        _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -322,6 +324,7 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
 
+        _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE)
         udf_obj = UserDefinedFunction(
             func,
             returnType=outputStructType,
@@ -360,6 +363,7 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
 
+        _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -398,9 +402,8 @@ class PandasCogroupedOps:
     ) -> "DataFrame":
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
-        from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
 
-        _validate_pandas_udf(func, schema, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
+        _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -426,6 +429,7 @@ class PandasCogroupedOps:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
 
+        _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
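At its core, the validation these call sites now share is an argument-count check on the user function. A simplified sketch of the idea (the helper name `_check_apply_func` is illustrative and the exact error-constructor keywords are an assumption; the real `_validate_pandas_udf` in the next file also handles the other eval types and type hints):

```python
from inspect import getfullargspec

from pyspark.errors import PySparkValueError
from pyspark.util import PythonEvalType


def _check_apply_func(f, evalType):
    """Illustrative arity check for grouped-map functions."""
    if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        argspec = getfullargspec(f)
        # Grouped-map functions must take (data) or (key, data).
        if len(argspec.args) not in (1, 2):
            raise PySparkValueError(
                error_class="INVALID_PANDAS_UDF",
                message_parameters={
                    "detail": "the function in groupby.applyInPandas must take "
                    "either one argument (data) or two arguments (key, data).",
                },
            )
```

Under this sketch, `_check_apply_func(lambda: 1, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)` raises the error shown above, while a one- or two-argument function passes.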
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 5922a5ced863..020105bb064a 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -432,7 +432,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
 
 
 # validate the pandas udf and return the adjusted eval type
-def _validate_pandas_udf(f, returnType, evalType) -> int:
+def _validate_pandas_udf(f, evalType) -> int:
     argspec = getfullargspec(f)
 
     # pandas UDF by type hints.
@@ -533,7 +533,7 @@ def _validate_pandas_udf(f, returnType, evalType) -> int:
 
 
 def _create_pandas_udf(f, returnType, evalType):
-    evalType = _validate_pandas_udf(f, returnType, evalType)
+    evalType = _validate_pandas_udf(f, evalType)
 
     if is_remote():
         from pyspark.sql.connect.udf import _create_udf as _create_connect_udf
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 1e86e12eb74f..a26d6d02a2bc 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -439,6 +439,26 @@ class GroupedApplyInPandasTestsMixin:
                 pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))
             )
 
+    def test_wrong_args_in_apply_func(self):
+        df1 = self.spark.range(11)
+        df2 = self.spark.range(22)
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").applyInArrow(lambda: 1, StructType([StructField("d", DoubleType())]))
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
+                lambda: 1, StructType([StructField("d", DoubleType())])
+            )
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
+                lambda: 1, StructType([StructField("d", DoubleType())])
+            )
+
     def test_unsupported_types(self):
         with self.quiet():
             self.check_unsupported_types()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org