This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1fb3e16a48d [SPARK-43966][SQL][PYTHON] Support non-deterministic table-valued functions
1fb3e16a48d is described below

commit 1fb3e16a48d826aed1ca9688a661281f750bbf5a
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Fri Jul 21 12:21:14 2023 +0800

    [SPARK-43966][SQL][PYTHON] Support non-deterministic table-valued functions
    
    ### What changes were proposed in this pull request?
    
    This PR supports non-deterministic table-valued functions. More 
specifically, it supports running non-deterministic Python UDTFs and built-in 
table-valued generator functions with non-deterministic input values.
    
    ### Why are the changes needed?
    
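    Previously, `CheckAnalysis` did not exempt `Generate` operators from its deterministic-expression check. As a result, non-deterministic Python UDTFs and built-in generator functions called with non-deterministic inputs (e.g. `explode(array(rand(0)))`) failed analysis with `INVALID_NON_DETERMINISTIC_EXPRESSIONS`. Allowing them makes table-valued functions more versatile.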
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Before this PR, Spark throws an exception when running a non-deterministic Python UDTF:
    ```
    select * from random_udtf(1)
    AnalysisException: [INVALID_NON_DETERMINISTIC_EXPRESSIONS] The operator expects a deterministic expression, ...
    ```
    
    After this PR, it is supported.
    
    ### How was this patch tested?
    
    Existing and new unit tests.
    
    Closes #42075 from allisonwang-db/spark-43966-non-det-udtf.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 python/pyspark/sql/tests/test_udtf.py                  | 18 ++++++------------
 .../spark/sql/catalyst/analysis/CheckAnalysis.scala    |  1 +
 .../analyzer-results/table-valued-functions.sql.out    |  6 ++++++
 .../sql-tests/inputs/table-valued-functions.sql        |  3 +++
 .../sql-tests/results/table-valued-functions.sql.out   |  8 ++++++++
 5 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index 3a4c021e990..13ea86ebcb2 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -376,14 +376,12 @@ class BaseUDTFTestsMixin:
 
         class RandomUDTF:
             def eval(self, a: int):
-                yield a * int(random.random() * 100),
+                yield a + int(random.random()),
 
         random_udtf = udtf(RandomUDTF, returnType="x: 
int").asNondeterministic()
-        # TODO(SPARK-43966): support non-deterministic UDTFs
-        with self.assertRaisesRegex(
-            AnalysisException, "The operator expects a deterministic 
expression"
-        ):
-            random_udtf(lit(1)).collect()
+        assertDataFrameEqual(random_udtf(lit(1)), [Row(x=1)])
+        self.spark.udtf.register("random_udtf", random_udtf)
+        assertDataFrameEqual(self.spark.sql("select * from random_udtf(1)"), 
[Row(x=1)])
 
     def test_udtf_with_nondeterministic_input(self):
         from pyspark.sql.functions import rand
@@ -391,13 +389,9 @@ class BaseUDTFTestsMixin:
         @udtf(returnType="x: int")
         class TestUDTF:
             def eval(self, a: int):
-                yield a + 1,
+                yield 1 if a > 100 else 0,
 
-        # TODO(SPARK-43966): support non-deterministic UDTFs
-        with self.assertRaisesRegex(
-            AnalysisException, " The operator expects a deterministic 
expression"
-        ):
-            TestUDTF(rand(0) * 100).collect()
+        assertDataFrameEqual(TestUDTF(rand(0) * 100), [Row(x=0)])
 
     def test_udtf_with_invalid_return_type(self):
         @udtf(returnType="int")
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index fca1b780088..8b04c8108bd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -752,6 +752,7 @@ trait CheckAnalysis extends PredicateHelper with 
LookupCatalog with QueryErrorsB
             !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
             !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] &&
             !o.isInstanceOf[Expand] &&
+            !o.isInstanceOf[Generate] &&
             // Lateral join is checked in checkSubqueryExpression.
             !o.isInstanceOf[LateralJoin] =>
             // The rule above is used to check Aggregate operator.
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out
index 49ad4bf19f7..6c29a0ec1db 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out
@@ -205,6 +205,12 @@ Project [k#x, v#x]
          +- OneRowRelation
 
 
+-- !query
+select * from explode(array(rand(0)))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
 -- !query
 select * from explode(null)
 -- !query analysis
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql 
b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
index 2b809f9a7c8..79d427bc209 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
@@ -43,6 +43,9 @@ select * from explode(map());
 select * from explode(array(1, 2)) t(c1);
 select * from explode(map('a', 1, 'b', 2)) t(k, v);
 
+-- explode with non-deterministic values
+select * from explode(array(rand(0)));
+
 -- explode with erroneous input
 select * from explode(null);
 select * from explode(null) t(c1);
diff --git 
a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out 
b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
index 578461d164a..1348110a83a 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
@@ -242,6 +242,14 @@ a  1
 b      2
 
 
+-- !query
+select * from explode(array(rand(0)))
+-- !query schema
+struct<col:double>
+-- !query output
+0.7604953758285915
+
+
 -- !query
 select * from explode(null)
 -- !query schema


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to