This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fe361e3a1b [SPARK-42115][SQL] Push down limit through Python UDFs
0fe361e3a1b is described below

commit 0fe361e3a1b5be04114402b78e62dd010703477b
Author: Kelvin Jiang <kelvin.ji...@databricks.com>
AuthorDate: Thu Feb 2 09:00:51 2023 +0800

    [SPARK-42115][SQL] Push down limit through Python UDFs
    
    ### What changes were proposed in this pull request?
    
    This PR adds cases in LimitPushDown to push limits through Python UDFs. In
    order to allow for this, we need to call LimitPushDown in SparkOptimizer
    after the "Extract Python UDFs" batch. We also add PushProjectionThroughLimit
    afterwards in order to plan CollectLimit.
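    
    As a minimal, hypothetical sketch (the UDF name, row counts, and use of
    `spark.range` below are illustrative and not part of this patch), this is the
    kind of query the new rule can now optimize:
    
    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, udf

    spark = SparkSession.builder.getOrCreate()

    @udf("string")
    def expensive_udf(s):
        # Stand-in for a costly per-row Python function.
        return s.upper()

    df = spark.range(1_000_000).selectExpr("cast(id as string) AS s")
    # With this change, a LocalLimit is kept below the BatchEvalPython node, so
    # the UDF only needs to evaluate enough rows to satisfy the limit rather
    # than the whole input.
    df.select(expensive_udf(col("s"))).limit(5).collect()
    ```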
    
    ### Why are the changes needed?
    
    Right now, LimitPushDown does not push limits through Python UDFs, which
    means that expensive Python UDFs can be run on potentially large amounts of
    input. This PR adds this capability, while ensuring that a GlobalLimit -
    LocalLimit pattern stays at the top in order to trigger the CollectLimit
    code path.
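    
    A quick, hedged way to observe the resulting plan shape (`plus_one` is just a
    placeholder UDF and the exact plan text varies by Spark version):
    
    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, udf

    spark = SparkSession.builder.getOrCreate()

    @udf("long")
    def plus_one(x):
        return x + 1

    # The optimized logical plan should keep GlobalLimit/LocalLimit at the top
    # (so the physical plan can use CollectLimit) while another LocalLimit sits
    # below the BatchEvalPython node that evaluates the UDF.
    spark.range(100).select(plus_one(col("id"))).limit(3).explain(extended=True)
    ```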
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added a unit test.
    
    Closes #39842 from kelvinjian-db/SPARK-42115-limit-through-python-udfs.
    
    Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/tests/pandas/test_pandas_udf_scalar.py     |  1 +
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  9 +++++
 .../plans/logical/pythonLogicalOperators.scala     |  3 ++
 .../spark/sql/catalyst/trees/TreePatterns.scala    |  1 +
 .../spark/sql/execution/SparkOptimizer.scala       |  2 +
 .../execution/python/ExtractPythonUDFsSuite.scala  | 43 ++++++++++++++++++++++
 6 files changed, 59 insertions(+)

diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index cbb26e45d2f..33c957fac58 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -981,6 +981,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             with self.assertRaisesRegex(Exception, "reached finally block"):
                 self.spark.range(1).select(test_close(col("id"))).collect()
 
+    @unittest.skip("LimitPushDown should push limits through Python UDFs so this won't occur")
     def test_scalar_iter_udf_close_early(self):
         tmp_dir = tempfile.mkdtemp()
         try:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 627e3952480..1233f2207f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -766,6 +766,15 @@ object LimitPushDown extends Rule[LogicalPlan] {
     // Push down local limit 1 if join type is LeftSemiOrAnti and join condition is empty.
     case j @ Join(_, right, LeftSemiOrAnti(_), None, _) if !right.maxRows.exists(_ <= 1) =>
       j.copy(right = maybePushLocalLimit(Literal(1, IntegerType), right))
+    // Push down limits through Python UDFs.
+    case LocalLimit(le, udf: BatchEvalPython) =>
+      LocalLimit(le, udf.copy(child = maybePushLocalLimit(le, udf.child)))
+    case LocalLimit(le, p @ Project(_, udf: BatchEvalPython)) =>
+      LocalLimit(le, p.copy(child = udf.copy(child = maybePushLocalLimit(le, udf.child))))
+    case LocalLimit(le, udf: ArrowEvalPython) =>
+      LocalLimit(le, udf.copy(child = maybePushLocalLimit(le, udf.child)))
+    case LocalLimit(le, p @ Project(_, udf: ArrowEvalPython)) =>
+      LocalLimit(le, p.copy(child = udf.copy(child = maybePushLocalLimit(le, udf.child))))
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index e97ff7808f1..1ce6808be60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
 import org.apache.spark.sql.types.StructType
@@ -141,6 +142,8 @@ trait BaseEvalPython extends UnaryNode {
   override def output: Seq[Attribute] = child.output ++ resultAttrs
 
   override def producedAttributes: AttributeSet = AttributeSet(resultAttrs)
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(EVAL_PYTHON_UDF)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 3957ad4af2d..48db1a4408d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -103,6 +103,7 @@ object TreePattern extends Enumeration  {
   val COMMAND: Value = Value
   val CTE: Value = Value
   val DISTINCT_LIKE: Value = Value
+  val EVAL_PYTHON_UDF: Value = Value
   val EVENT_TIME_WATERMARK: Value = Value
   val EXCEPT: Value = Value
   val FILTER: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index c61fd9ce10f..06e3888e7de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -79,7 +79,9 @@ class SparkOptimizer(
       // The eval-python node may be between Project/Filter and the scan node, which breaks
       // column pruning and filter push-down. Here we rerun the related optimizer rules.
       ColumnPruning,
+      LimitPushDown,
       PushPredicateThroughNonJoin,
+      PushProjectionThroughLimit,
       RemoveNoopOperators) :+
     Batch("User Provided Optimizers", fixedPoint, 
experimentalMethods.extraOptimizations: _*) :+
     Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
index 8519357dab0..0ab8691801d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.python
 
+import org.apache.spark.sql.catalyst.plans.logical.{ArrowEvalPython, BatchEvalPython, Limit, LocalLimit}
 import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SparkPlanTest}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -193,5 +194,47 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession {
     assert(pythonEvalNodes5.size == 1)
     assert(pythonEvalNodes5.head.udfs.size == 2)
   }
+
+  test("Infers LocalLimit for Python evaluator") {
+    val df = Seq(("Hello", 4), ("World", 8)).toDF("a", "b")
+
+    // Check that PushProjectionThroughLimit brings GlobalLimit - LocalLimit to the top (for
+    // CollectLimit) and that LimitPushDown keeps LocalLimit under UDF.
+    val df2 = df.limit(1).select(batchedPythonUDF(col("b")))
+    assert(df2.queryExecution.optimizedPlan match {
+      case Limit(_, _) => true
+    })
+    assert(df2.queryExecution.optimizedPlan.find {
+      case b: BatchEvalPython => b.child.isInstanceOf[LocalLimit]
+      case _ => false
+    }.isDefined)
+
+    val df3 = df.limit(1).select(scalarPandasUDF(col("b")))
+    assert(df3.queryExecution.optimizedPlan match {
+      case Limit(_, _) => true
+    })
+    assert(df3.queryExecution.optimizedPlan.find {
+      case a: ArrowEvalPython => a.child.isInstanceOf[LocalLimit]
+      case _ => false
+    }.isDefined)
+
+    val df4 = df.limit(1).select(batchedPythonUDF(col("b")), scalarPandasUDF(col("b")))
+    assert(df4.queryExecution.optimizedPlan match {
+      case Limit(_, _) => true
+    })
+    val evalsWithLimit = df4.queryExecution.optimizedPlan.collect {
+      case b: BatchEvalPython if b.child.isInstanceOf[LocalLimit] => b
+      case a: ArrowEvalPython if a.child.isInstanceOf[LocalLimit] => a
+    }
+    assert(evalsWithLimit.length == 2)
+
+    // Check that LimitPushDown properly pushes LocalLimit past EvalPython operators.
+    val df5 = df.select(batchedPythonUDF(col("b")), scalarPandasUDF(col("b"))).limit(1)
+    df5.queryExecution.optimizedPlan.foreach {
+      case b: BatchEvalPython => assert(b.child.isInstanceOf[LocalLimit])
+      case a: ArrowEvalPython => assert(a.child.isInstanceOf[LocalLimit])
+      case _ =>
+    }
+  }
 }
 

