This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 73c645584c5 [SPARK-43967][PYTHON] Support regular Python UDTFs with empty return values
73c645584c5 is described below

commit 73c645584c569e5edf7732b1bc3d58e816213d7a
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Tue Jul 18 14:59:09 2023 +0900

    [SPARK-43967][PYTHON] Support regular Python UDTFs with empty return values
    
    ### What changes were proposed in this pull request?
    This PR adds support for regular (non-arrow-optimized) Python UDTFs that return empty results, for example:
    ```
    def eval(self):
        ...
    ```
    or
    ```
    def eval(self):
        yield
    ```
    This feature is already available in arrow-optimized UDTFs.
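    
    For illustration, a minimal sketch of the aligned behavior (the class name and the `a: int` schema are illustrative, mirroring the updated tests below):
    ```
    from pyspark.sql.functions import lit, udtf

    @udtf(returnType="a: int")
    class TestUDTF:
        def eval(self, a: int):
            ...

    # With this change, collect() returns an empty result instead of
    # raising a TypeError:
    assert TestUDTF(lit(1)).collect() == []
    ```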
    
    ### Why are the changes needed?
    To align the behavior of regular Python UDTFs with arrow-optimized UDTFs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. After this PR, users can run regular Python UDTFs with an empty return statement.
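    
    For example, a bare `yield` now produces a single all-NULL row instead of failing with a NullPointerException (a sketch mirroring the updated test below; `assertDataFrameEqual` is imported from `pyspark.testing`):
    ```
    from pyspark.sql import Row
    from pyspark.sql.functions import lit, udtf
    from pyspark.testing import assertDataFrameEqual

    @udtf(returnType="a: int")
    class TestUDTF:
        def eval(self, a: int):
            yield

    # The empty yield becomes one row of NULLs rather than an error.
    assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=None)])
    ```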
    
    ### How was this patch tested?
    
    Existing UTs.
    
    Closes #42044 from allisonwang-db/spark-43967-empty-return.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_udtf.py              | 35 ++--------------------
 python/pyspark/worker.py                           | 20 +++++++++++--
 .../sql/execution/python/EvalPythonUDTFExec.scala  | 11 ++++++-
 3 files changed, 30 insertions(+), 36 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index f109302dec5..ec3379accca 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -164,17 +164,14 @@ class BaseUDTFTestsMixin:
             def eval(self, a: int):
                 ...
 
-        # TODO(SPARK-43967): Support Python UDTFs with empty return values
-        with self.assertRaisesRegex(PythonException, "TypeError"):
-            TestUDTF(lit(1)).collect()
+        self.assertEqual(TestUDTF(lit(1)).collect(), [])
 
         @udtf(returnType="a: int")
         class TestUDTF:
             def eval(self, a: int):
                 return
 
-        with self.assertRaisesRegex(PythonException, "TypeError"):
-            TestUDTF(lit(1)).collect()
+        self.assertEqual(TestUDTF(lit(1)).collect(), [])
 
     def test_udtf_with_conditional_return(self):
         class TestUDTF:
@@ -195,9 +192,7 @@ class BaseUDTFTestsMixin:
             def eval(self, a: int):
                 yield
 
-        # TODO(SPARK-43967): Support Python UDTFs with empty return values
-        with self.assertRaisesRegex(Py4JJavaError, "java.lang.NullPointerException"):
-            TestUDTF(lit(1)).collect()
+        assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=None)])
 
     def test_udtf_with_none_output(self):
         @udtf(returnType="a: int")
@@ -807,21 +802,6 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
         func = udtf(TestUDTF, returnType="a: int")
         self.assertEqual(func(lit(1)).collect(), [Row(a=1)])
 
-    def test_udtf_eval_with_no_return(self):
-        @udtf(returnType="a: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                ...
-
-        self.assertEqual(TestUDTF(lit(1)).collect(), [])
-
-        @udtf(returnType="a: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                return
-
-        self.assertEqual(TestUDTF(lit(1)).collect(), [])
-
     def test_udtf_terminate_with_wrong_num_output(self):
         # The error message for arrow-optimized UDTF is different from regular UDTF.
         err_msg = "The number of columns in the result does not match the 
specified schema."
@@ -848,15 +828,6 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
         with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).show()
 
-    def test_udtf_with_empty_yield(self):
-        @udtf(returnType="a: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield
-
-        # Arrow-optimized UDTF can support this.
-        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=None)])
-
     def test_udtf_with_wrong_num_output(self):
         # The error message for arrow-optimized UDTF is different.
         err_msg = "The number of columns in the result does not match the 
specified schema."
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8d07772b214..06927a4a30b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -664,7 +664,21 @@ def read_udtf(pickleSer, infile, eval_type):
         def wrap_udtf(f, return_type):
             assert return_type.needConversion()
             toInternal = return_type.toInternal
-            return lambda *a: map(toInternal, f(*a))
+
+            # Evaluate the function and return a tuple back to the executor.
+            def evaluate(*a) -> tuple:
+                res = f(*a)
+                if res is None:
+                    # If the function returns None or does not have an explicit return statement,
+                    # an empty tuple is returned to the executor.
+                    # This is because directly constructing tuple(None) results in an exception.
+                    return tuple()
+                else:
+                    # If the function returns a result, we map it to the internal representation and
+                    # return the results as a tuple.
+                    return tuple(map(toInternal, res))
+
+            return evaluate
 
         eval = wrap_udtf(getattr(udtf, "eval"), return_type)
 
@@ -677,11 +691,11 @@ def read_udtf(pickleSer, infile, eval_type):
         def mapper(_, it):
             try:
                 for a in it:
-                    yield tuple(eval(*[a[o] for o in arg_offsets]))
+                    yield eval(*[a[o] for o in arg_offsets])
             finally:
                 if terminate is not None:
                     try:
-                        yield tuple(terminate())
+                        yield terminate()
                     except BaseException as e:
                         raise PySparkRuntimeError(
                             error_class="UDTF_EXEC_ERROR",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
index 827b2fc2bb3..fab417a0f86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
@@ -103,6 +103,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
         }
 
       val joined = new JoinedRow
+      val nullRow = new GenericInternalRow(udtf.elementSchema.length)
       val resultProj = UnsafeProjection.create(output, output)
 
       outputRowIterator.flatMap { outputRows =>
@@ -118,7 +119,15 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
         // from the UDTF are from the `terminate()` call. We leave the left side as the last
         // element of its child output to keep it consistent with the Generate implementation
         // and Hive UDTFs.
-        outputRows.map(r => resultProj(joined.withRight(r)))
+        outputRows.map { r =>
+          // When the UDTF's result is None, such as `def eval(): yield`,
+          // we join it with a null row to avoid NullPointerException.
+          if (r == null) {
+            resultProj(joined.withRight(nullRow))
+          } else {
+            resultProj(joined.withRight(r))
+          }
+        }
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
