This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 36b93d07eb9 [SPARK-44560][PYTHON][CONNECT] Improve tests and documentation for Arrow Python UDF
36b93d07eb9 is described below

commit 36b93d07eb961905647c42fac80e22efdfb15f4f
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Thu Jul 27 13:45:05 2023 -0700

    [SPARK-44560][PYTHON][CONNECT] Improve tests and documentation for Arrow Python UDF
    
    ### What changes were proposed in this pull request?
    
    - Test on complex return type
    - Remove complex return type constraints for Arrow Python UDF on Spark Connect
    - Update documentation of the related Spark conf
    
    The change targets both Spark 3.5 and 4.0.
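
    For illustration, a minimal sketch of what the lifted constraint allows (assumes an existing session `spark`, vanilla PySpark or Spark Connect; the UDF and column alias below are hypothetical, modeled on the new test):

    ```python
    from pyspark.sql.functions import udf

    # Opt in to Arrow-optimized Python UDFs (default is false).
    spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

    # Complex return types such as array<long> are no longer rejected on
    # Spark Connect; the Arrow path is taken whenever the function has
    # at least one argument.
    to_pair = udf(lambda x: [x, x + 1], "array<long>")
    spark.range(3).select(to_pair("id").alias("pair")).show()
    ```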
    
    ### Why are the changes needed?
    Testability and parity with vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #42178 from xinrong-meng/conf.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
    (cherry picked from commit 5f6537409383e2dbdd699108f708567c37db8151)
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 python/pyspark/sql/connect/udf.py                        | 10 ++--------
 python/pyspark/sql/tests/test_arrow_python_udf.py        |  5 -----
 python/pyspark/sql/tests/test_udf.py                     | 16 ++++++++++++++++
 .../scala/org/apache/spark/sql/internal/SQLConf.scala    |  3 +--
 4 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 0a5d06618b3..2d7e423d3d5 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.expressions import (
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.types import UnparsedDataType
-from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType
+from pyspark.sql.types import DataType, StringType
 from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
 from pyspark.errors import PySparkTypeError
 
@@ -70,18 +70,12 @@ def _create_py_udf(
         is_arrow_enabled = useArrow
 
     regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
-    return_type = regular_udf.returnType
     try:
         is_func_with_args = len(getfullargspec(f).args) > 0
     except TypeError:
         is_func_with_args = False
-    is_output_atomic_type = (
-        not isinstance(return_type, StructType)
-        and not isinstance(return_type, MapType)
-        and not isinstance(return_type, ArrayType)
-    )
     if is_arrow_enabled:
-        if is_output_atomic_type and is_func_with_args:
+        if is_func_with_args:
             return _create_arrow_py_udf(regular_udf)
         else:
             warnings.warn(
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index 264ea0b901f..f48f07666e1 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -47,11 +47,6 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
     def test_register_java_udaf(self):
         super(PythonUDFArrowTests, self).test_register_java_udaf()
 
-    # TODO(SPARK-43903): Standardize ArrayType conversion for Python UDF
-    @unittest.skip("Inconsistent ArrayType conversion with/without Arrow.")
-    def test_nested_array(self):
-        super(PythonUDFArrowTests, self).test_nested_array()
-
     def test_complex_input_types(self):
         row = (
             self.spark.range(1)
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 8ffcb5e05a2..239ff27813b 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -882,6 +882,22 @@ class BaseUDFTestsMixin(object):
         row = df.select(f("nested_array")).first()
         self.assertEquals(row[0], [[1, 2], [3, 4], [4, 5]])
 
+    def test_complex_return_types(self):
+        row = (
+            self.spark.range(1)
+            .selectExpr("array(1, 2, 3) as array", "map('a', 'b') as map", "struct(1, 2) as struct")
+            .select(
+                udf(lambda x: x, "array<int>")("array"),
+                udf(lambda x: x, "map<string,string>")("map"),
+                udf(lambda x: x, "struct<col1:int,col2:int>")("struct"),
+            )
+            .first()
+        )
+
+        self.assertEquals(row[0], [1, 2, 3])
+        self.assertEquals(row[1], {"a": "b"})
+        self.assertEquals(row[2], Row(col1=1, col2=2))
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d4987e3443f..7fd960de37f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2917,8 +2917,7 @@ object SQLConf {
   val PYTHON_UDF_ARROW_ENABLED =
     buildConf("spark.sql.execution.pythonUDF.arrow.enabled")
       .doc("Enable Arrow optimization in regular Python UDFs. This 
optimization " +
-        "can only be enabled for atomic output types and input types except 
struct and map types " +
-        "when the given function takes at least one argument.")
+        "can only be enabled when the given function takes at least one 
argument.")
       .version("3.4.0")
       .booleanConf
       .createWithDefault(false)
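
For reference, a short hedged sketch of the updated conf in practice (the `SparkSession` setup and the zero-argument UDF are illustrative assumptions, not part of this patch):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

# Enable Arrow optimization for regular Python UDFs at session level.
spark = (
    SparkSession.builder
    .config("spark.sql.execution.pythonUDF.arrow.enabled", "true")
    .getOrCreate()
)

# Per the updated doc, the optimization applies only when the function
# takes at least one argument; a zero-argument UDF falls back to the
# regular (non-Arrow) path with a user warning, as in the udf.py hunk above.
one = udf(lambda: 1, "int")
spark.range(1).select(one()).show()
```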


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
