This is an automated email from the ASF dual-hosted git repository. xinrong pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 36b93d07eb9 [SPARK-44560][PYTHON][CONNECT] Improve tests and documentation for Arrow Python UDF 36b93d07eb9 is described below commit 36b93d07eb961905647c42fac80e22efdfb15f4f Author: Xinrong Meng <xinr...@apache.org> AuthorDate: Thu Jul 27 13:45:05 2023 -0700 [SPARK-44560][PYTHON][CONNECT] Improve tests and documentation for Arrow Python UDF ### What changes were proposed in this pull request? - Test on complex return type - Remove complex return type constraints for Arrow Python UDF on Spark Connect - Update documentation of the related Spark conf The change targets both Spark 3.5 and 4.0. ### Why are the changes needed? Testability and parity with vanilla PySpark. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. Closes #42178 from xinrong-meng/conf. Authored-by: Xinrong Meng <xinr...@apache.org> Signed-off-by: Xinrong Meng <xinr...@apache.org> (cherry picked from commit 5f6537409383e2dbdd699108f708567c37db8151) Signed-off-by: Xinrong Meng <xinr...@apache.org> --- python/pyspark/sql/connect/udf.py | 10 ++-------- python/pyspark/sql/tests/test_arrow_python_udf.py | 5 ----- python/pyspark/sql/tests/test_udf.py | 16 ++++++++++++++++ .../scala/org/apache/spark/sql/internal/SQLConf.scala | 3 +-- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 0a5d06618b3..2d7e423d3d5 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -35,7 +35,7 @@ from pyspark.sql.connect.expressions import ( ) from pyspark.sql.connect.column import Column from pyspark.sql.connect.types import UnparsedDataType -from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType +from pyspark.sql.types import DataType, StringType from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration from pyspark.errors import PySparkTypeError @@ -70,18 +70,12 @@ def _create_py_udf( is_arrow_enabled = useArrow regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF) - return_type = regular_udf.returnType try: is_func_with_args = len(getfullargspec(f).args) > 0 except TypeError: is_func_with_args = False - is_output_atomic_type = ( - not isinstance(return_type, StructType) - and not isinstance(return_type, MapType) - and not isinstance(return_type, ArrayType) - ) if is_arrow_enabled: - if is_output_atomic_type and is_func_with_args: + if is_func_with_args: return _create_arrow_py_udf(regular_udf) else: warnings.warn( diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index 264ea0b901f..f48f07666e1 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -47,11 +47,6 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin): def test_register_java_udaf(self): super(PythonUDFArrowTests, self).test_register_java_udaf() - # TODO(SPARK-43903): Standardize ArrayType conversion for Python UDF - @unittest.skip("Inconsistent ArrayType conversion with/without Arrow.") - def test_nested_array(self): - super(PythonUDFArrowTests, self).test_nested_array() - def test_complex_input_types(self): row = ( self.spark.range(1) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 8ffcb5e05a2..239ff27813b 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -882,6 +882,22 @@ class BaseUDFTestsMixin(object): row = df.select(f("nested_array")).first() self.assertEquals(row[0], [[1, 2], [3, 4], [4, 5]]) + def test_complex_return_types(self): + row = ( + self.spark.range(1) + .selectExpr("array(1, 2, 3) as array", "map('a', 'b') as map", "struct(1, 2) as struct") + .select( + udf(lambda x: x, "array<int>")("array"), + udf(lambda x: x, "map<string,string>")("map"), + udf(lambda x: x, "struct<col1:int,col2:int>")("struct"), + ) + .first() + ) + + self.assertEquals(row[0], [1, 2, 3]) + self.assertEquals(row[1], {"a": "b"}) + self.assertEquals(row[2], Row(col1=1, col2=2)) + class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d4987e3443f..7fd960de37f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2917,8 +2917,7 @@ object SQLConf { val PYTHON_UDF_ARROW_ENABLED = buildConf("spark.sql.execution.pythonUDF.arrow.enabled") .doc("Enable Arrow optimization in regular Python UDFs. This optimization " + - "can only be enabled for atomic output types and input types except struct and map types " + - "when the given function takes at least one argument.") + "can only be enabled when the given function takes at least one argument.") .version("3.4.0") .booleanConf .createWithDefault(false) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org