This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 1a01a492c05 [SPARK-40121][PYTHON][SQL] Initialize projection used for Python UDF
1a01a492c05 is described below

commit 1a01a492c051bb861c480f224a3c310e133e4d01
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Aug 18 12:23:02 2022 +0900

    [SPARK-40121][PYTHON][SQL] Initialize projection used for Python UDF
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to initialize the mutable projection in `EvalPythonExec` with the partition index so that non-deterministic expressions (such as `rand()`) can be evaluated as inputs to Python UDFs.
    
    ### Why are the changes needed?
    
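    To make Python UDFs work with non-deterministic expressions. Without the initialization, evaluating such an expression as a Python UDF input fails with a `NullPointerException`.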
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    ```python
    from pyspark.sql.functions import udf, rand
    spark.range(10).select(udf(lambda x: x, "double")(rand())).show()
    ```
    
    **Before**
    
    ```
    java.lang.NullPointerException
            at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
            at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$10(EvalPythonExec.scala:126)
            at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
            at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
            at scala.collection.Iterator$GroupedIterator.takeDestructively(Iterator.scala:1161)
            at scala.collection.Iterator$GroupedIterator.go(Iterator.scala:1176)
            at scala.collection.Iterator$GroupedIterator.fill(Iterator.scala:1213)
    ```
    
    **After**
    
    ```
    +----------------------------------+
    |<lambda>rand(-2507211707257730645)|
    +----------------------------------+
    |                0.7691724424045242|
    |               0.09602244075319044|
    |                0.3006471278112862|
    |                0.4182649571961977|
    |               0.29349096650900974|
    |                0.7987097908937618|
    |                0.5324802583101007|
    |                  0.72460930912789|
    |                0.1367749768412846|
    |               0.17277322931919348|
    +----------------------------------+
    ```
    
    ### How was this patch tested?
    
    Manually tested, and a unit test was added.
    
    Closes #37552 from HyukjinKwon/SPARK-40121.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 336c9bc535895530cc3983b24e7507229fa9570d)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_udf.py                              | 8 +++++++-
 .../org/apache/spark/sql/execution/python/EvalPythonExec.scala    | 1 +
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 40deac992c4..34ac08cb818 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -24,7 +24,7 @@ import datetime
 
 from pyspark import SparkContext, SQLContext
 from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import udf, assert_true, lit
+from pyspark.sql.functions import udf, assert_true, lit, rand
 from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.types import (
     StringType,
@@ -798,6 +798,12 @@ class UDFTests(ReusedSQLTestCase):
         finally:
             shutil.rmtree(path)
 
+    def test_udf_with_rand(self):
+        # SPARK-40121: rand() with Python UDF.
+        self.assertEqual(
+            len(self.spark.range(10).select(udf(lambda x: x, DoubleType())(rand())).collect()), 10
+        )
+
 
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
index c567a70e1d3..f117a408566 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -116,6 +116,7 @@ trait EvalPythonExec extends UnaryExecNode {
         }.toArray
       }.toArray
       val projection = MutableProjection.create(allInputs.toSeq, child.output)
+      projection.initialize(context.partitionId())
       val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
         StructField(s"_$i", dt)
       }.toSeq)
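
The one-line fix calls `projection.initialize(context.partitionId())` right after the projection is created. The Python sketch below is a toy model of why that matters, not Spark code: `ToyRand` and `project` are hypothetical names. It assumes, consistent with the NullPointerException in the Before trace but not spelled out in this commit, that a non-deterministic expression such as `rand()` only builds its per-partition state when it is initialized with a partition index. The seeding scheme here (seed plus partition index) is an illustrative assumption.

```python
import random


class ToyRand:
    """Hypothetical stand-in for a non-deterministic expression like rand(seed).

    NOT Spark's API: it only models the contract that per-partition state
    must be created by initialize() before eval() is called.
    """

    def __init__(self, seed):
        self.seed = seed
        self.rng = None  # unknown until a partition index is supplied

    def initialize(self, partition_index):
        # Assumed scheme: derive a per-partition generator from seed + index.
        self.rng = random.Random(self.seed + partition_index)

    def eval(self):
        if self.rng is None:
            # Toy analogue of the NPE: evaluating before initialize().
            raise RuntimeError("eval() called before initialize()")
        return self.rng.random()


def project(rows, expr, partition_index, initialize=True):
    """Evaluate expr once per input row, optionally initializing it first.

    initialize=False mimics the projection before the fix;
    initialize=True mimics projection.initialize(context.partitionId()).
    """
    if initialize:
        expr.initialize(partition_index)
    return [expr.eval() for _ in rows]
```

For example, `project(range(10), ToyRand(42), partition_index=0)` returns ten floats in [0, 1), while passing `initialize=False` raises `RuntimeError`. Keying the state on the partition index keeps each partition's stream reproducible for a fixed seed while letting different partitions draw different values.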

