This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new c9bfcb9448b [SPARK-43967][PYTHON] Support regular Python UDTFs with empty return values c9bfcb9448b is described below commit c9bfcb9448b51985ad9a5361fbaaf828ad670cdc Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Tue Jul 18 14:59:09 2023 +0900 [SPARK-43967][PYTHON] Support regular Python UDTFs with empty return values ### What changes were proposed in this pull request? This PR adds support for regular (non-arrow-optimized) Python UDTFs that return empty results, for example: ``` def eval(self): ... ``` or ``` def eval(self): yield ``` This feature is already available in arrow-optimized UDTFs. ### Why are the changes needed? To align the behavior of regular Python UDTFs with arrow-optimized UDTFs. ### Does this PR introduce _any_ user-facing change? Yes. After this PR, users can run regular Python UDTFs with an empty return statement. ### How was this patch tested? Existing UTs. Closes #42044 from allisonwang-db/spark-43967-empty-return. Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/tests/test_udtf.py | 35 ++-------------------- python/pyspark/worker.py | 20 +++++++++++-- .../sql/execution/python/EvalPythonUDTFExec.scala | 11 ++++++- 3 files changed, 30 insertions(+), 36 deletions(-) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index f109302dec5..ec3379accca 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -164,17 +164,14 @@ class BaseUDTFTestsMixin: def eval(self, a: int): ... 
- # TODO(SPARK-43967): Support Python UDTFs with empty return values - with self.assertRaisesRegex(PythonException, "TypeError"): - TestUDTF(lit(1)).collect() + self.assertEqual(TestUDTF(lit(1)).collect(), []) @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): return - with self.assertRaisesRegex(PythonException, "TypeError"): - TestUDTF(lit(1)).collect() + self.assertEqual(TestUDTF(lit(1)).collect(), []) def test_udtf_with_conditional_return(self): class TestUDTF: @@ -195,9 +192,7 @@ class BaseUDTFTestsMixin: def eval(self, a: int): yield - # TODO(SPARK-43967): Support Python UDTFs with empty return values - with self.assertRaisesRegex(Py4JJavaError, "java.lang.NullPointerException"): - TestUDTF(lit(1)).collect() + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=None)]) def test_udtf_with_none_output(self): @udtf(returnType="a: int") @@ -807,21 +802,6 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin): func = udtf(TestUDTF, returnType="a: int") self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) - def test_udtf_eval_with_no_return(self): - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - ... - - self.assertEqual(TestUDTF(lit(1)).collect(), []) - - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - return - - self.assertEqual(TestUDTF(lit(1)).collect(), []) - def test_udtf_terminate_with_wrong_num_output(self): # The error message for arrow-optimized UDTF is different from regular UDTF. err_msg = "The number of columns in the result does not match the specified schema." @@ -848,15 +828,6 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin): with self.assertRaisesRegex(PythonException, err_msg): TestUDTF(lit(1)).show() - def test_udtf_with_empty_yield(self): - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - yield - - # Arrow-optimized UDTF can support this. 
- self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=None)]) - def test_udtf_with_wrong_num_output(self): # The error message for arrow-optimized UDTF is different. err_msg = "The number of columns in the result does not match the specified schema." diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8c12312da27..2445b46970c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -579,7 +579,21 @@ def read_udtf(pickleSer, infile, eval_type): def wrap_udtf(f, return_type): assert return_type.needConversion() toInternal = return_type.toInternal - return lambda *a: map(toInternal, f(*a)) + + # Evaluate the function and return a tuple back to the executor. + def evaluate(*a) -> tuple: + res = f(*a) + if res is None: + # If the function returns None or does not have an explicit return statement, + # an empty tuple is returned to the executor. + # This is because directly constructing tuple(None) results in an exception. + return tuple() + else: + # If the function returns a result, we map it to the internal representation and + # returns the results as a tuple. 
+ return tuple(map(toInternal, res)) + + return evaluate eval = wrap_udtf(getattr(udtf, "eval"), return_type) @@ -592,11 +606,11 @@ def read_udtf(pickleSer, infile, eval_type): def mapper(_, it): try: for a in it: - yield tuple(eval(*[a[o] for o in arg_offsets])) + yield eval(*[a[o] for o in arg_offsets]) finally: if terminate is not None: try: - yield tuple(terminate()) + yield terminate() except BaseException as e: raise PySparkRuntimeError( error_class="UDTF_EXEC_ERROR", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala index 827b2fc2bb3..fab417a0f86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala @@ -103,6 +103,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode { } val joined = new JoinedRow + val nullRow = new GenericInternalRow(udtf.elementSchema.length) val resultProj = UnsafeProjection.create(output, output) outputRowIterator.flatMap { outputRows => @@ -118,7 +119,15 @@ trait EvalPythonUDTFExec extends UnaryExecNode { // from the UDTF are from the `terminate()` call. We leave the left side as the last // element of its child output to keep it consistent with the Generate implementation // and Hive UDTFs. - outputRows.map(r => resultProj(joined.withRight(r))) + outputRows.map { r => + // When the UDTF's result is None, such as `def eval(): yield`, + // we join it with a null row to avoid NullPointerException. + if (r == null) { + resultProj(joined.withRight(nullRow)) + } else { + resultProj(joined.withRight(r)) + } + } } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org